25
25
# the blocktype and element type - something like S = similar(A, BlockType(...))
26
26
function _similar_S (A:: AbstractBlockSparseMatrix , s_axis)
27
27
T = real (eltype (A))
28
- return BlockSparseArray {T,2,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis))
28
+ return BlockSparseMatrix {T,Diagonal{T,Vector{T}}} (undef, (s_axis, s_axis))
29
+ end
30
+
31
+ function similar_output (
32
+ :: typeof (svd_compact!), A, s_axis:: AbstractUnitRange , alg:: MatrixAlgebraKit.Algorithm
33
+ )
34
+ U = similar (A, axes (A, 1 ), s_axis)
35
+ S = _similar_S (A, s_axis)
36
+ Vt = similar (A, s_axis, axes (A, 2 ))
37
+ return U, S, Vt
29
38
end
30
39
31
40
function MatrixAlgebraKit. initialize_output (
@@ -34,9 +43,9 @@ function MatrixAlgebraKit.initialize_output(
34
43
bm, bn = blocksize (A)
35
44
bmn = min (bm, bn)
36
45
37
- brows = blocklengths (axes (A, 1 ))
38
- bcols = blocklengths (axes (A, 2 ))
39
- slengths = Vector {Int } (undef, bmn)
46
+ brows = blockaxeses (axes (A, 1 ))
47
+ bcols = blockaxeses (axes (A, 2 ))
48
+ s_axeses = Vector {eltype(brows) } (undef, bmn)
40
49
41
50
# fill in values for blocks that are present
42
51
bIs = collect (eachblockstoredindex (A))
@@ -46,21 +55,19 @@ function MatrixAlgebraKit.initialize_output(
46
55
row, col = Int .(Tuple (bI))
47
56
nrows = brows[row]
48
57
ncols = bcols[col]
49
- slengths [col] = min (nrows, ncols)
58
+ s_axeses [col] = min (nrows, ncols)
50
59
end
51
60
52
61
# fill in values for blocks that aren't present, pairing them in order of occurence
53
62
# this is a convention, which at least gives the expected results for blockdiagonal
54
63
emptyrows = setdiff (1 : bm, browIs)
55
64
emptycols = setdiff (1 : bn, bcolIs)
56
65
for (row, col) in zip (emptyrows, emptycols)
57
- slengths [col] = min (brows[row], bcols[col])
66
+ s_axeses [col] = min (brows[row], bcols[col])
58
67
end
59
68
60
- s_axis = blockedrange (slengths)
61
- U = similar (A, axes (A, 1 ), s_axis)
62
- S = _similar_S (A, s_axis)
63
- Vt = similar (A, s_axis, axes (A, 2 ))
69
+ s_axis = mortar_axis (s_axeses)
70
+ U, S, Vt = similar_output (svd_compact!, A, s_axis, alg)
64
71
65
72
# allocate output
66
73
for bI in eachblockstoredindex (A)
0 commit comments