Skip to content

Commit 5f4563d

Browse files
committed
Preserve sector information better in slicing
1 parent 4e44292 commit 5f4563d

File tree

3 files changed

+51
-25
lines changed

3 files changed

+51
-25
lines changed

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
6868
end
6969
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
7070

71+
# TODO: Move this to a `BlockArraysExtensions` library.
72+
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndices)
73+
# TODO: Is this a good definition? It ignores `indices.indices`.
74+
return a[indices.blocks]
75+
end
76+
7177
# Generalization of to `BlockArrays._blockslice`:
7278
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L13-L14
7379
# Used by `BlockArrays.unblock`, which is used in `to_indices`

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,11 @@ function blockedunitrange_getindices(
368368
a::AbstractBlockedUnitRange,
369369
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
370370
)
371-
return mortar(map(b -> a[b], blocks(indices)))
371+
blks = map(b -> a[b], blocks(indices))
372+
# Preserve any extra structure in the axes, like a
373+
# Kronecker structure, symmetry sectors, etc.
374+
ax = mortar_axis(map(b -> axis(a[b]), blocks(indices)))
375+
return mortar(blks, (ax,))
372376
end
373377
function blockedunitrange_getindices(
374378
a::AbstractBlockedUnitRange,

test/test_factorizations.jl

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,23 @@ test_params = Iterators.product(blockszs, eltypes)
146146
@test test_svd(a, usv_empty)
147147

148148
# test blockdiagonal
149+
rng = StableRNG(123)
149150
for i in LinearAlgebra.diagind(blocks(a))
150151
I = CartesianIndices(blocks(a))[i]
151-
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
152+
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
152153
end
153154
usv = svd_compact(a)
154155
@test test_svd(a, usv)
155156

156-
perm = Random.randperm(length(m))
157+
rng = StableRNG(123)
158+
perm = Random.randperm(rng, length(m))
157159
b = a[Block.(perm), Block.(1:length(n))]
158160
usv = svd_compact(b)
159161
@test test_svd(b, usv)
160162

161163
# test permuted blockdiagonal with missing row/col
162-
I_removed = rand(eachblockstoredindex(b))
164+
rng = StableRNG(123)
165+
I_removed = rand(rng, eachblockstoredindex(b))
163166
c = copy(b)
164167
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
165168
usv = svd_compact(c)
@@ -176,20 +179,23 @@ end
176179
@test test_svd(a, usv_empty; full=true)
177180

178181
# test blockdiagonal
182+
rng = StableRNG(123)
179183
for i in LinearAlgebra.diagind(blocks(a))
180184
I = CartesianIndices(blocks(a))[i]
181-
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
185+
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
182186
end
183187
usv = svd_full(a)
184188
@test test_svd(a, usv; full=true)
185189

186-
perm = Random.randperm(length(m))
190+
rng = StableRNG(123)
191+
perm = Random.randperm(rng, length(m))
187192
b = a[Block.(perm), Block.(1:length(n))]
188193
usv = svd_full(b)
189194
@test test_svd(b, usv; full=true)
190195

191196
# test permuted blockdiagonal with missing row/col
192-
I_removed = rand(eachblockstoredindex(b))
197+
rng = StableRNG(123)
198+
I_removed = rand(rng, eachblockstoredindex(b))
193199
c = copy(b)
194200
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
195201
usv = svd_full(c)
@@ -203,9 +209,10 @@ end
203209
a = BlockSparseArray{T}(undef, m, n)
204210

205211
# test blockdiagonal
212+
rng = StableRNG(123)
206213
for i in LinearAlgebra.diagind(blocks(a))
207214
I = CartesianIndices(blocks(a))[i]
208-
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
215+
a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i]))
209216
end
210217

211218
minmn = min(size(a)...)
@@ -236,7 +243,8 @@ end
236243
@test (V1ᴴ * V1ᴴ' LinearAlgebra.I)
237244

238245
# test permuted blockdiagonal
239-
perm = Random.randperm(length(m))
246+
rng = StableRNG(123)
247+
perm = Random.randperm(rng, length(m))
240248
b = a[Block.(perm), Block.(1:length(n))]
241249
for trunc in (truncrank(r), trunctol(atol))
242250
U1, S1, V1ᴴ = svd_trunc(b; trunc)
@@ -270,8 +278,9 @@ end
270278
@testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
271279
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
272280
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
273-
A[Block(1, 1)] = randn(T, i, k)
274-
A[Block(2, 2)] = randn(T, j, l)
281+
rng = StableRNG(123)
282+
A[Block(1, 1)] = randn(rng, T, i, k)
283+
A[Block(2, 2)] = randn(rng, T, j, l)
275284
Q, R = qr_compact(A)
276285
@test Matrix(Q'Q) LinearAlgebra.I
277286
@test A Q * R
@@ -281,8 +290,9 @@ end
281290
@testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
282291
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
283292
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
284-
A[Block(1, 1)] = randn(T, i, k)
285-
A[Block(2, 2)] = randn(T, j, l)
293+
rng = StableRNG(123)
294+
A[Block(1, 1)] = randn(rng, T, i, k)
295+
A[Block(2, 2)] = randn(rng, T, j, l)
286296
Q, R = qr_full(A)
287297
Q′, R′ = qr_full(Matrix(A))
288298
@test size(Q) == size(Q′)
@@ -296,8 +306,9 @@ end
296306
@testset "lq_compact" for T in (Float32, Float64, ComplexF32, ComplexF64)
297307
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
298308
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
299-
A[Block(1, 1)] = randn(T, i, k)
300-
A[Block(2, 2)] = randn(T, j, l)
309+
rng = StableRNG(123)
310+
A[Block(1, 1)] = randn(rng, T, i, k)
311+
A[Block(2, 2)] = randn(rng, T, j, l)
301312
L, Q = lq_compact(A)
302313
@test Matrix(Q * Q') LinearAlgebra.I
303314
@test A L * Q
@@ -307,8 +318,9 @@ end
307318
@testset "lq_full" for T in (Float32, Float64, ComplexF32, ComplexF64)
308319
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
309320
A = BlockSparseArray{T}(undef, ([i, j], [k, l]))
310-
A[Block(1, 1)] = randn(T, i, k)
311-
A[Block(2, 2)] = randn(T, j, l)
321+
rng = StableRNG(123)
322+
A[Block(1, 1)] = randn(rng, T, i, k)
323+
A[Block(2, 2)] = randn(rng, T, j, l)
312324
L, Q = lq_full(A)
313325
L′, Q′ = lq_full(Matrix(A))
314326
@test size(L) == size(L′)
@@ -321,8 +333,9 @@ end
321333

322334
@testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
323335
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
324-
A[Block(1, 1)] = randn(T, 3, 2)
325-
A[Block(2, 2)] = randn(T, 4, 3)
336+
rng = StableRNG(123)
337+
A[Block(1, 1)] = randn(rng, T, 3, 2)
338+
A[Block(2, 2)] = randn(rng, T, 4, 3)
326339

327340
U, C = left_polar(A)
328341
@test U * C A
@@ -331,8 +344,9 @@ end
331344

332345
@testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
333346
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
334-
A[Block(1, 1)] = randn(T, 2, 3)
335-
A[Block(2, 2)] = randn(T, 3, 4)
347+
rng = StableRNG(123)
348+
A[Block(1, 1)] = randn(rng, T, 2, 3)
349+
A[Block(2, 2)] = randn(rng, T, 3, 4)
336350

337351
C, U = right_polar(A)
338352
@test C * U A
@@ -341,8 +355,9 @@ end
341355

342356
@testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
343357
A = BlockSparseArray{T}(undef, ([3, 4], [2, 3]))
344-
A[Block(1, 1)] = randn(T, 3, 2)
345-
A[Block(2, 2)] = randn(T, 4, 3)
358+
rng = StableRNG(123)
359+
A[Block(1, 1)] = randn(rng, T, 3, 2)
360+
A[Block(2, 2)] = randn(rng, T, 4, 3)
346361

347362
for kind in (:polar, :qr, :svd)
348363
U, C = left_orth(A; kind)
@@ -358,8 +373,9 @@ end
358373

359374
@testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64)
360375
A = BlockSparseArray{T}(undef, ([2, 3], [3, 4]))
361-
A[Block(1, 1)] = randn(T, 2, 3)
362-
A[Block(2, 2)] = randn(T, 3, 4)
376+
rng = StableRNG(123)
377+
A[Block(1, 1)] = randn(rng, T, 2, 3)
378+
A[Block(2, 2)] = randn(rng, T, 3, 4)
363379

364380
for kind in (:lq, :polar, :svd)
365381
C, U = right_orth(A; kind)

0 commit comments

Comments
 (0)