Skip to content

Commit 4c632e8

Browse files
committed
change transform into functions
1 parent 5457095 commit 4c632e8

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

src/factorizations/svd.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,10 @@ function MatrixAlgebraKit.svd_compact!(
138138
)
139139
check_input(svd_compact!, A, USVᴴ, alg)
140140

141-
Ad, rowperm, colperm = blockdiagonalize(A)
141+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
142142
Ud, S, Vᴴd = svd_compact!(Ad, BlockDiagonalAlgorithm(alg))
143-
144-
inv_rowperm = Block.(invperm(Int.(rowperm)))
145-
U = Ud[inv_rowperm, :]
146-
147-
inv_colperm = Block.(invperm(Int.(colperm)))
148-
Vᴴ = Vᴴd[:, inv_colperm]
143+
U = transform_rows(Ud)
144+
Vᴴ = transform_cols(Vᴴd)
149145

150146
return U, S, Vᴴ
151147
end
@@ -177,14 +173,10 @@ function MatrixAlgebraKit.svd_full!(
177173
)
178174
check_input(svd_full!, A, USVᴴ, alg)
179175

180-
Ad, rowperm, colperm = blockdiagonalize(A)
176+
Ad, transform_rows, transform_cols = blockdiagonalize(A)
181177
Ud, S, Vᴴd = svd_full!(Ad, BlockDiagonalAlgorithm(alg))
182-
183-
inv_rowperm = Block.(invperm(Int.(rowperm)))
184-
U = Ud[inv_rowperm, :]
185-
186-
inv_colperm = Block.(invperm(Int.(colperm)))
187-
Vᴴ = Vᴴd[:, inv_colperm]
178+
U = transform_rows(Ud)
179+
Vᴴ = transform_cols(Vᴴd)
188180

189181
return U, S, Vᴴ
190182
end

src/factorizations/utility.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ function blockdiagonalize(A::AbstractBlockSparseMatrix)
4040
emptycols = setdiff(Block.(1:blocksize(A, 2)), colperm)
4141
append!(colperm, emptycols)
4242

43-
return A[rowperm, colperm], rowperm, colperm
43+
invrowperm = Block.(invperm(Int.(rowperm)))
44+
transform_rows(A) = A[invrowperm, :]
45+
46+
invcolperm = Block.(invperm(Int.(colperm)))
47+
transform_cols(A) = A[:, invcolperm]
48+
49+
return A[rowperm, colperm], transform_rows, transform_cols
4450
end
4551

4652
function isblockdiagonal(A::AbstractBlockSparseMatrix)

0 commit comments

Comments
 (0)