Skip to content

Commit 41dd2a9

Browse files
committed
add eigvals and svdvals, refactor transformers
1 parent 710939b commit 41dd2a9

File tree

5 files changed

+122
-44
lines changed

5 files changed

+122
-44
lines changed

src/factorizations/eig.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,9 @@
11
using BlockArrays: blocksizes
22
using DiagonalArrays: diagonal
33
using LinearAlgebra: LinearAlgebra, Diagonal
4-
using MatrixAlgebraKit:
5-
MatrixAlgebraKit,
6-
TruncationStrategy,
7-
default_eig_algorithm,
8-
default_eigh_algorithm,
9-
diagview,
10-
eig_full!,
11-
eig_trunc!,
12-
eig_vals!,
13-
eigh_full!,
14-
eigh_trunc!,
15-
eigh_vals!,
16-
findtruncated
4+
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
5+
using MatrixAlgebraKit: default_eig_algorithm, eig_full!, eig_vals!
6+
using MatrixAlgebraKit: default_eigh_algorithm, eigh_full!, eigh_vals!
177

188
for f in [:default_eig_algorithm, :default_eigh_algorithm]
199
@eval begin
@@ -96,10 +86,10 @@ for f in [:eig_full!, :eigh_full!]
9686
A::AbstractBlockSparseMatrix, DV, alg::BlockPermutedDiagonalAlgorithm
9787
)
9888
MatrixAlgebraKit.check_input($f, A, DV, alg)
99-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
89+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
10090
Dd, Vd = $f(Ad, BlockDiagonalAlgorithm(alg))
101-
D = transform_rows(Dd)
102-
V = transform_cols(Vd)
91+
D = transform_rows(Dd, invrowperm)
92+
V = transform_cols(Vd, invcolperm)
10393
return D, V
10494
end
10595
function MatrixAlgebraKit.$f(
@@ -140,16 +130,47 @@ for f in [:eig_vals!, :eigh_vals!]
140130
@eval begin
141131
function MatrixAlgebraKit.initialize_output(
142132
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm
133+
)
134+
return nothing
135+
end
136+
function MatrixAlgebraKit.initialize_output(
137+
::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
143138
)
144139
T = output_type($f, blocktype(A))
145140
return similar(A, BlockType(T), axes(A, 1))
146141
end
142+
function MatrixAlgebraKit.check_input(
143+
::typeof($f), A::AbstractBlockSparseMatrix, D, ::BlockPermutedDiagonalAlgorithm
144+
)
145+
@assert isblockpermuteddiagonal(A)
146+
return nothing
147+
end
148+
function MatrixAlgebraKit.check_input(
149+
::typeof($f), A::AbstractBlockSparseMatrix, D, ::BlockDiagonalAlgorithm
150+
)
151+
@assert isa(D, AbstractBlockSparseVector)
152+
@assert eltype(D) === $(f == :eig_vals! ? complex : real)(eltype(A))
153+
@assert axes(A, 1) == axes(A, 2)
154+
@assert (axes(A, 1),) == axes(D)
155+
@assert isblockdiagonal(A)
156+
return nothing
157+
end
158+
147159
function MatrixAlgebraKit.$f(
148160
A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm
149161
)
162+
MatrixAlgebraKit.check_input($f, A, D, alg)
163+
Ad, (invrowperm, _) = blockdiagonalize(A)
164+
Dd = $f(Ad, BlockDiagonalAlgorithm(alg))
165+
return transform_rows(Dd, invrowperm)
166+
end
167+
function MatrixAlgebraKit.$f(
168+
A::AbstractBlockSparseMatrix, D, alg::BlockDiagonalAlgorithm
169+
)
170+
MatrixAlgebraKit.check_input($f, A, D, alg)
150171
for I in eachblockstoredindex(A)
151172
block = @view!(A[I])
152-
D[I] = $f(block, block_algorithm(alg, block))
173+
D[Tuple(I)[1]] = $f(block, block_algorithm(alg, block))
153174
end
154175
return D
155176
end

src/factorizations/lq.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ function MatrixAlgebraKit.lq_compact!(
8787
A::AbstractBlockSparseMatrix, LQ, alg::BlockPermutedDiagonalAlgorithm
8888
)
8989
MatrixAlgebraKit.check_input(lq_compact!, A, LQ, alg)
90-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
90+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
9191
Ld, Qd = lq_compact!(Ad, BlockDiagonalAlgorithm(alg))
92-
L = transform_rows(Ld)
93-
Q = transform_cols(Qd)
92+
L = transform_rows(Ld, invrowperm)
93+
Q = transform_cols(Qd, invcolperm)
9494
return L, Q
9595
end
9696
function MatrixAlgebraKit.lq_compact!(
@@ -119,10 +119,10 @@ function MatrixAlgebraKit.lq_full!(
119119
A::AbstractBlockSparseMatrix, LQ, alg::BlockPermutedDiagonalAlgorithm
120120
)
121121
MatrixAlgebraKit.check_input(lq_full!, A, LQ, alg)
122-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
122+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
123123
Ld, Qd = lq_full!(Ad, BlockDiagonalAlgorithm(alg))
124-
L = transform_rows(Ld)
125-
Q = transform_cols(Qd)
124+
L = transform_rows(Ld, invrowperm)
125+
Q = transform_cols(Qd, invcolperm)
126126
return L, Q
127127
end
128128
function MatrixAlgebraKit.lq_full!(

src/factorizations/qr.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,10 @@ function MatrixAlgebraKit.qr_compact!(
8888
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
8989
)
9090
check_input(qr_compact!, A, QR, alg)
91-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
91+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
9292
Qd, Rd = qr_compact!(Ad, BlockDiagonalAlgorithm(alg))
93-
Q = transform_rows(Qd)
94-
R = transform_cols(Rd)
93+
Q = transform_rows(Qd, invrowperm)
94+
R = transform_cols(Rd, invcolperm)
9595
return Q, R
9696
end
9797

@@ -121,10 +121,10 @@ function MatrixAlgebraKit.qr_full!(
121121
A::AbstractBlockSparseMatrix, QR, alg::BlockPermutedDiagonalAlgorithm
122122
)
123123
check_input(qr_full!, A, QR, alg)
124-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
124+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
125125
Qd, Rd = qr_full!(Ad, BlockDiagonalAlgorithm(alg))
126-
Q = transform_rows(Qd)
127-
R = transform_cols(Rd)
126+
Q = transform_rows(Qd, invrowperm)
127+
R = transform_cols(Rd, invcolperm)
128128
return Q, R
129129
end
130130

src/factorizations/svd.jl

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using DiagonalArrays: diagonaltype
22
using MatrixAlgebraKit:
3-
MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!
3+
MatrixAlgebraKit, check_input, default_svd_algorithm, svd_compact!, svd_full!, svd_vals!
44
using TypeParameterAccessors: realtype
55

66
function MatrixAlgebraKit.default_svd_algorithm(
@@ -15,7 +15,15 @@ function output_type(
1515
f::Union{typeof(svd_compact!),typeof(svd_full!)}, A::Type{<:AbstractMatrix{T}}
1616
) where {T}
1717
USVᴴ = Base.promote_op(f, A)
18-
return isconcretetype(USVᴴ) ? USVᴴ : Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
18+
return if isconcretetype(USVᴴ)
19+
USVᴴ
20+
else
21+
Tuple{AbstractMatrix{T},AbstractMatrix{realtype(T)},AbstractMatrix{T}}
22+
end
23+
end
24+
function output_type(::typeof(svd_vals!), A::Type{<:AbstractMatrix{T}}) where {T}
25+
S = Base.promote_op(svd_vals!, A)
26+
return isconcretetype(S) ? S : AbstractVector{real(T)}
1927
end
2028

2129
function MatrixAlgebraKit.initialize_output(
@@ -46,7 +54,6 @@ function MatrixAlgebraKit.initialize_output(
4654
)
4755
return nothing
4856
end
49-
5057
function MatrixAlgebraKit.initialize_output(
5158
::typeof(svd_full!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
5259
)
@@ -58,6 +65,24 @@ function MatrixAlgebraKit.initialize_output(
5865
return U, S, Vᴴ
5966
end
6067

68+
function MatrixAlgebraKit.initialize_output(
69+
::typeof(svd_vals!), ::AbstractBlockSparseMatrix, ::BlockDiagonalAlgorithm
70+
)
71+
return nothing
72+
end
73+
function MatrixAlgebraKit.initialize_output(
74+
::typeof(svd_vals!), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm
75+
)
76+
brows = eachblockaxis(axes(A, 1))
77+
bcols = eachblockaxis(axes(A, 2))
78+
# using the property that zip stops as soon as one of the iterators is exhausted
79+
s_axes = map(splat(infimum), zip(brows, bcols))
80+
s_axis = mortar_axis(s_axes)
81+
82+
BS = output_type(svd_vals!, blocktype(A))
83+
return similar(A, BlockType(BS), S_axes)
84+
end
85+
6186
function MatrixAlgebraKit.check_input(
6287
::typeof(svd_compact!),
6388
A::AbstractBlockSparseMatrix,
@@ -66,7 +91,6 @@ function MatrixAlgebraKit.check_input(
6691
)
6792
@assert isblockpermuteddiagonal(A)
6893
end
69-
7094
function MatrixAlgebraKit.check_input(
7195
::typeof(svd_compact!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ), ::BlockDiagonalAlgorithm
7296
)
@@ -87,7 +111,6 @@ function MatrixAlgebraKit.check_input(
87111
@assert isblockpermuteddiagonal(A)
88112
return nothing
89113
end
90-
91114
function MatrixAlgebraKit.check_input(
92115
::typeof(svd_full!), A::AbstractBlockSparseMatrix, (U, S, Vᴴ), ::BlockDiagonalAlgorithm
93116
)
@@ -102,15 +125,30 @@ function MatrixAlgebraKit.check_input(
102125
return nothing
103126
end
104127

128+
function MatrixAlgebraKit.check_input(
129+
::typeof(svd_vals!), A::AbstractBlockSparseMatrix, S, ::BlockPermutedDiagonalAlgorithm
130+
)
131+
@assert isblockpermuteddiagonal(A)
132+
return nothing
133+
end
134+
function MatrixAlgebraKit.check_input(
135+
::typeof(svd_vals!), A::AbstractBlockSparseMatrix, S, ::BlockDiagonalAlgorithm
136+
)
137+
@assert isa(S, AbstractBlockSparseVector)
138+
@assert real(eltype(A)) == eltype(S)
139+
@assert isblockdiagonal(A)
140+
return nothing
141+
end
142+
105143
function MatrixAlgebraKit.svd_compact!(
106144
A::AbstractBlockSparseMatrix, USVᴴ, alg::BlockPermutedDiagonalAlgorithm
107145
)
108146
check_input(svd_compact!, A, USVᴴ, alg)
109147

110-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
148+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
111149
Ud, S, Vᴴd = svd_compact!(Ad, BlockDiagonalAlgorithm(alg))
112-
U = transform_rows(Ud)
113-
Vᴴ = transform_cols(Vᴴd)
150+
U = transform_rows(Ud, invrowperm)
151+
Vᴴ = transform_cols(Vᴴd, invcolperm)
114152

115153
return U, S, Vᴴ
116154
end
@@ -143,10 +181,10 @@ function MatrixAlgebraKit.svd_full!(
143181
)
144182
check_input(svd_full!, A, USVᴴ, alg)
145183

146-
Ad, transform_rows, transform_cols = blockdiagonalize(A)
184+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
147185
Ud, S, Vᴴd = svd_full!(Ad, BlockDiagonalAlgorithm(alg))
148-
U = transform_rows(Ud)
149-
Vᴴ = transform_cols(Vᴴd)
186+
U = transform_rows(Ud, invrowperm)
187+
Vᴴ = transform_cols(Vᴴd, invcolperm)
150188

151189
return U, S, Vᴴ
152190
end
@@ -181,3 +219,21 @@ function MatrixAlgebraKit.svd_full!(
181219

182220
return U, S, Vᴴ
183221
end
222+
223+
function MatrixAlgebraKit.svd_vals!(
224+
A::AbstractBlockSparseMatrix, S, alg::BlockPermutedDiagonalAlgorithm
225+
)
226+
MatrixAlgebraKit.check_input(svd_vals!, A, S, alg)
227+
Ad, _ = blockdiagonalize(A)
228+
return svd_vals!(Ad, BlockDiagonalAlgorithm(alg))
229+
end
230+
function MatrixAlgebraKit.svd_vals!(
231+
A::AbstractBlockSparseMatrix, S, alg::BlockDiagonalAlgorithm
232+
)
233+
MatrixAlgebraKit.check_input(svd_vals!, A, S, alg)
234+
for I in eachblockstoredindex(A)
235+
block = @view!(A[I])
236+
S[Tuple(I)[1]] = $f(block, block_algorithm(alg, block))
237+
end
238+
return S
239+
end

src/factorizations/utility.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ function supremum(r1::AbstractUnitRange, r2::AbstractUnitRange)
2020
end
2121
end
2222

23+
transform_rows(A::AbstractMatrix, invrowperm) = A[invrowperm, :]
24+
transform_rows(A::AbstractVector, invrowperm) = A[invrowperm]
25+
transform_cols(A::AbstractMatrix, invcolperm) = A[:, invcolperm]
26+
2327
function blockdiagonalize(A::AbstractBlockSparseMatrix)
2428
# sort in order to avoid depending on internal details such as dictionary order
2529
bIs = sort!(collect(eachblockstoredindex(A)); by=Int last Tuple)
@@ -41,12 +45,9 @@ function blockdiagonalize(A::AbstractBlockSparseMatrix)
4145
append!(colperm, emptycols)
4246

4347
invrowperm = Block.(invperm(Int.(rowperm)))
44-
transform_rows(A) = A[invrowperm, :]
45-
4648
invcolperm = Block.(invperm(Int.(colperm)))
47-
transform_cols(A) = A[:, invcolperm]
4849

49-
return A[rowperm, colperm], transform_rows, transform_cols
50+
return A[rowperm, colperm], (invrowperm, invcolperm)
5051
end
5152

5253
function isblockdiagonal(A::AbstractBlockSparseMatrix)

0 commit comments

Comments
 (0)