Skip to content

Commit 04f2d36

Browse files
committed
[WIP] Make SVD more general to accommodate graded arrays
1 parent 9eb742b commit 04f2d36

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,22 @@ using BlockArrays:
99
BlockSlice,
1010
BlockVector,
1111
block,
12+
blockedrange,
1213
blockindex,
14+
blocklengths,
1315
findblock,
1416
findblockindex,
1517
mortar
1618

19+
function blockaxeses(a::AbstractUnitRange)
20+
Base.require_one_based_indexing(a)
21+
return map(Base.OneTo, blocklengths(a))
22+
end
23+
24+
function mortar_axis(axeses)
25+
return blockedrange(length.(axeses))
26+
end
27+
1728
# Custom `BlockedUnitRange` constructor that takes a unit range
1829
# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`.
1930
function blockedunitrange(a::AbstractUnitRange, blocklengths)

src/factorizations/svd.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,16 @@ end
2525
# the blocktype and element type - something like S = similar(A, BlockType(...))
2626
function _similar_S(A::AbstractBlockSparseMatrix, s_axis)
2727
T = real(eltype(A))
28-
return BlockSparseArray{T,2,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
28+
return BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, (s_axis, s_axis))
29+
end
30+
31+
function similar_output(
32+
::typeof(svd_compact!), A, s_axis::AbstractUnitRange, alg::MatrixAlgebraKit.Algorithm
33+
)
34+
U = similar(A, axes(A, 1), s_axis)
35+
S = _similar_S(A, s_axis)
36+
Vt = similar(A, s_axis, axes(A, 2))
37+
return U, S, Vt
2938
end
3039

3140
function MatrixAlgebraKit.initialize_output(
@@ -34,9 +43,9 @@ function MatrixAlgebraKit.initialize_output(
3443
bm, bn = blocksize(A)
3544
bmn = min(bm, bn)
3645

37-
brows = blocklengths(axes(A, 1))
38-
bcols = blocklengths(axes(A, 2))
39-
slengths = Vector{Int}(undef, bmn)
46+
brows = blockaxeses(axes(A, 1))
47+
bcols = blockaxeses(axes(A, 2))
48+
s_axeses = Vector{eltype(brows)}(undef, bmn)
4049

4150
# fill in values for blocks that are present
4251
bIs = collect(eachblockstoredindex(A))
@@ -46,21 +55,19 @@ function MatrixAlgebraKit.initialize_output(
4655
row, col = Int.(Tuple(bI))
4756
nrows = brows[row]
4857
ncols = bcols[col]
49-
slengths[col] = min(nrows, ncols)
58+
s_axeses[col] = min(nrows, ncols)
5059
end
5160

5261
# fill in values for blocks that aren't present, pairing them in order of occurence
5362
# this is a convention, which at least gives the expected results for blockdiagonal
5463
emptyrows = setdiff(1:bm, browIs)
5564
emptycols = setdiff(1:bn, bcolIs)
5665
for (row, col) in zip(emptyrows, emptycols)
57-
slengths[col] = min(brows[row], bcols[col])
66+
s_axeses[col] = min(brows[row], bcols[col])
5867
end
5968

60-
s_axis = blockedrange(slengths)
61-
U = similar(A, axes(A, 1), s_axis)
62-
S = _similar_S(A, s_axis)
63-
Vt = similar(A, s_axis, axes(A, 2))
69+
s_axis = mortar_axis(s_axeses)
70+
U, S, Vt = similar_output(svd_compact!, A, s_axis, alg)
6471

6572
# allocate output
6673
for bI in eachblockstoredindex(A)

0 commit comments

Comments
 (0)