Skip to content

Commit 720d5ec

Browse files
committed
Full SVD
1 parent d4fe33e commit 720d5ec

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/factorizations/svd.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,23 @@ function MatrixAlgebraKit.initialize_output(
8989
return U, S, Vt
9090
end
9191

92+
function similar_output(
93+
::typeof(svd_full!), A, s_axis::AbstractUnitRange, alg::MatrixAlgebraKit.AbstractAlgorithm
94+
)
95+
U = similar(A, axes(A, 1), s_axis)
96+
T = real(eltype(A))
97+
S = similar(A, T, (s_axis, axes(A, 2)))
98+
Vt = similar(A, axes(A, 2), axes(A, 2))
99+
return U, S, Vt
100+
end
101+
92102
function MatrixAlgebraKit.initialize_output(
93103
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
94104
)
95105
bm, bn = blocksize(A)
96106

97-
brows = blocklengths(axes(A, 1))
98-
slengths = copy(brows)
107+
brows = eachblockaxis(axes(A, 1))
108+
s_axes = copy(brows)
99109

100110
# fill in values for blocks that are present
101111
bIs = collect(eachblockstoredindex(A))
@@ -104,25 +114,22 @@ function MatrixAlgebraKit.initialize_output(
104114
for bI in eachblockstoredindex(A)
105115
row, col = Int.(Tuple(bI))
106116
nrows = brows[row]
107-
slengths[col] = nrows
117+
s_axes[col] = nrows
108118
end
109119

110120
# fill in values for blocks that aren't present, pairing them in order of occurence
111121
# this is a convention, which at least gives the expected results for blockdiagonal
112122
emptyrows = setdiff(1:bm, browIs)
113123
emptycols = setdiff(1:bn, bcolIs)
114124
for (row, col) in zip(emptyrows, emptycols)
115-
slengths[col] = brows[row]
125+
s_axes[col] = brows[row]
116126
end
117127
for (i, k) in enumerate((length(emptycols) + 1):length(emptyrows))
118-
slengths[bn + i] = brows[emptyrows[k]]
128+
s_axes[bn + i] = brows[emptyrows[k]]
119129
end
120130

121-
s_axis = blockedrange(slengths)
122-
U = similar(A, axes(A, 1), s_axis)
123-
Tr = real(eltype(A))
124-
S = similar(A, Tr, (s_axis, axes(A, 2)))
125-
Vt = similar(A, axes(A, 2), axes(A, 2))
131+
s_axis = mortar_axis(s_axes)
132+
U, S, Vt = similar_output(svd_full!, A, s_axis, alg)
126133

127134
# allocate output
128135
for bI in eachblockstoredindex(A)

0 commit comments

Comments
 (0)