Skip to content

Commit 703b936

Browse files
committed
fix matrixalgebrakit
1 parent 500506a commit 703b936

File tree

2 files changed

+51
-109
lines changed

2 files changed

+51
-109
lines changed

src/kroneckerarray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ i.e. that can be written as `AB = A ⊗ B`.
66
"""
77
abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end
88

9+
const AbstractKroneckerVector{T} = AbstractKroneckerArray{T, 1}
10+
const AbstractKroneckerMatrix{T} = AbstractKroneckerArray{T, 2}
11+
912
@doc """
1013
arg1(AB::AbstractKroneckerArray{T, N})
1114

src/matrixalgebrakit.jl

Lines changed: 48 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,23 @@
11
using MatrixAlgebraKit:
22
MatrixAlgebraKit,
3-
AbstractAlgorithm,
4-
TruncationStrategy,
5-
default_eig_algorithm,
6-
default_eigh_algorithm,
7-
default_lq_algorithm,
8-
default_polar_algorithm,
9-
default_qr_algorithm,
10-
default_svd_algorithm,
11-
eig_full!,
12-
eig_full,
13-
eig_trunc!,
14-
eig_trunc,
15-
eig_vals!,
16-
eig_vals,
17-
eigh_full!,
18-
eigh_full,
19-
eigh_trunc!,
20-
eigh_trunc,
21-
eigh_vals!,
22-
eigh_vals,
3+
AbstractAlgorithm, TruncationStrategy,
4+
default_eig_algorithm, default_eigh_algorithm, default_lq_algorithm,
5+
default_polar_algorithm, default_qr_algorithm, default_svd_algorithm,
6+
eig_full!, eig_full, eig_trunc!, eig_trunc, eig_vals!, eig_vals,
7+
eigh_full!, eigh_full, eigh_trunc!, eigh_trunc, eigh_vals!, eigh_vals,
238
initialize_output,
24-
left_null!,
25-
left_null,
26-
left_orth!,
27-
left_orth,
28-
left_polar!,
29-
left_polar,
30-
lq_compact!,
31-
lq_compact,
32-
lq_full!,
33-
lq_full,
34-
qr_compact!,
35-
qr_compact,
36-
qr_full!,
37-
qr_full,
38-
right_null!,
39-
right_null,
40-
right_orth!,
41-
right_orth,
42-
right_polar!,
43-
right_polar,
44-
svd_compact!,
45-
svd_compact,
46-
svd_full!,
47-
svd_full,
48-
svd_trunc!,
49-
svd_trunc,
50-
svd_vals!,
51-
svd_vals,
9+
left_null!, left_null, left_orth!, left_orth, left_polar!, left_polar,
10+
lq_compact!, lq_compact, lq_full!, lq_full,
11+
qr_compact!, qr_compact, qr_full!, qr_full,
12+
right_null!, right_null, right_orth!, right_orth, right_polar!, right_polar,
13+
svd_compact!, svd_compact, svd_full!, svd_full, svd_trunc!, svd_trunc, svd_vals!, svd_vals,
5214
truncate
5315

5416
using DiagonalArrays: DiagonalArrays, diagview
55-
function DiagonalArrays.diagview(a::KroneckerMatrix)
17+
function DiagonalArrays.diagview(a::AbstractKroneckerMatrix)
5618
return diagview(arg1(a)) diagview(arg2(a))
5719
end
58-
MatrixAlgebraKit.diagview(a::KroneckerMatrix) = diagview(a)
20+
MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = diagview(a)
5921

6022
struct KroneckerAlgorithm{A1, A2} <: AbstractAlgorithm
6123
arg1::A1
@@ -66,53 +28,35 @@ end
6628

6729
using MatrixAlgebraKit:
6830
copy_input,
69-
eig_full,
70-
eig_vals,
71-
eigh_full,
72-
eigh_vals,
73-
qr_compact,
74-
qr_full,
75-
left_null,
76-
left_orth,
77-
left_polar,
78-
lq_compact,
79-
lq_full,
80-
right_null,
81-
right_orth,
82-
right_polar,
83-
svd_compact,
84-
svd_full
31+
eig_full, eig_vals, eigh_full, eigh_vals,
32+
qr_compact, qr_full,
33+
left_null, left_orth, left_polar,
34+
lq_compact, lq_full,
35+
right_null, right_orth, right_polar,
36+
svd_compact, svd_full
8537

8638
for f in [
87-
:eig_full,
88-
:eigh_full,
89-
:qr_compact,
90-
:qr_full,
91-
:left_polar,
92-
:lq_compact,
93-
:lq_full,
94-
:right_polar,
95-
:svd_compact,
96-
:svd_full,
39+
:eig_full, :eigh_full,
40+
:qr_compact, :qr_full,
41+
:lq_compact, :lq_full,
42+
:left_polar, :right_polar,
43+
:svd_compact, :svd_full,
9744
]
9845
@eval begin
99-
function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix)
46+
function MatrixAlgebraKit.copy_input(::typeof($f), a::AbstractKroneckerMatrix)
10047
return copy_input($f, arg1(a)) copy_input($f, arg2(a))
10148
end
10249
end
10350
end
10451

10552
for f in [
106-
:default_eig_algorithm,
107-
:default_eigh_algorithm,
108-
:default_lq_algorithm,
109-
:default_qr_algorithm,
110-
:default_polar_algorithm,
111-
:default_svd_algorithm,
53+
:default_eig_algorithm, :default_eigh_algorithm,
54+
:default_lq_algorithm, :default_qr_algorithm,
55+
:default_polar_algorithm, :default_svd_algorithm,
11256
]
11357
@eval begin
11458
function MatrixAlgebraKit.$f(
115-
A::Type{<:KroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...
59+
A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...
11660
)
11761
A1, A2 = argument_types(A)
11862
return KroneckerAlgorithm(
@@ -123,16 +67,11 @@ for f in [
12367
end
12468

12569
for f in [
126-
:eig_full,
127-
:eigh_full,
128-
:left_polar,
129-
:lq_compact,
130-
:lq_full,
131-
:qr_compact,
132-
:qr_full,
133-
:right_polar,
134-
:svd_compact,
135-
:svd_full,
70+
:eig_full, :eigh_full,
71+
:left_polar, :right_polar,
72+
:lq_compact, :lq_full,
73+
:qr_compact, :qr_full,
74+
:svd_compact, :svd_full,
13675
]
13776
f! = Symbol(f, :!)
13877
@eval begin
@@ -142,10 +81,10 @@ for f in [
14281
return nothing
14382
end
14483
function MatrixAlgebraKit.$f!(
145-
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs...
84+
a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm
14685
)
147-
a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...)
148-
a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...)
86+
a1 = $f(arg1(a), arg1(alg))
87+
a2 = $f(arg2(a), arg2(alg))
14988
return a1 .⊗ a2
15089
end
15190
end
@@ -160,10 +99,10 @@ for f in [:eig_vals, :eigh_vals, :svd_vals]
16099
return nothing
161100
end
162101
function MatrixAlgebraKit.$f!(
163-
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs...
102+
a::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm
164103
)
165-
a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...)
166-
a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...)
104+
a1 = $f(arg1(a), arg1(alg))
105+
a2 = $f(arg2(a), arg2(alg))
167106
return a1 a2
168107
end
169108
end
@@ -172,11 +111,11 @@ end
172111
for f in [:left_orth, :right_orth]
173112
f! = Symbol(f, :!)
174113
@eval begin
175-
function MatrixAlgebraKit.initialize_output(::typeof($f!), a::KroneckerMatrix)
114+
function MatrixAlgebraKit.initialize_output(::typeof($f!), a::AbstractKroneckerMatrix)
176115
return nothing
177116
end
178117
function MatrixAlgebraKit.$f!(
179-
a::KroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...
118+
a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...
180119
)
181120
a1 = $f(arg1(a); kwargs..., kwargs1...)
182121
a2 = $f(arg2(a); kwargs..., kwargs2...)
@@ -188,11 +127,11 @@ end
188127
for f in [:left_null, :right_null]
189128
f! = Symbol(f, :!)
190129
@eval begin
191-
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
130+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::AbstractKroneckerMatrix)
192131
return nothing
193132
end
194133
function MatrixAlgebraKit.$f!(
195-
a::KroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...
134+
a::AbstractKroneckerMatrix, F; kwargs1 = (;), kwargs2 = (;), kwargs...
196135
)
197136
a1 = $f(arg1(a); kwargs..., kwargs1...)
198137
a2 = $f(arg2(a); kwargs..., kwargs2...)
@@ -248,7 +187,7 @@ function to_truncated_indices(values::KroneckerVector, I)
248187
end
249188

250189
function MatrixAlgebraKit.findtruncated(
251-
values::KroneckerVector, strategy::KroneckerTruncationStrategy
190+
values::AbstractKroneckerVector, strategy::KroneckerTruncationStrategy
252191
)
253192
I = findtruncated(Vector(values), strategy.strategy)
254193
return to_truncated_indices(values, I)
@@ -257,12 +196,12 @@ end
257196
for f in [:eig_trunc!, :eigh_trunc!]
258197
@eval begin
259198
function MatrixAlgebraKit.truncate(
260-
::typeof($f), DV::NTuple{2, KroneckerMatrix}, strategy::TruncationStrategy
199+
::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy
261200
)
262201
return truncate($f, DV, KroneckerTruncationStrategy(strategy))
263202
end
264203
function MatrixAlgebraKit.truncate(
265-
::typeof($f), (D, V)::NTuple{2, KroneckerMatrix}, strategy::KroneckerTruncationStrategy
204+
::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy
266205
)
267206
I = findtruncated(diagview(D), strategy)
268207
return (D[I, I], V[(:) × (:), I]), I
@@ -271,13 +210,13 @@ for f in [:eig_trunc!, :eigh_trunc!]
271210
end
272211

273212
function MatrixAlgebraKit.truncate(
274-
f::typeof(svd_trunc!), USVᴴ::NTuple{3, KroneckerMatrix}, strategy::TruncationStrategy
213+
f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy
275214
)
276215
return truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy))
277216
end
278217
function MatrixAlgebraKit.truncate(
279218
::typeof(svd_trunc!),
280-
(U, S, Vᴴ)::NTuple{3, KroneckerMatrix},
219+
(U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix},
281220
strategy::KroneckerTruncationStrategy,
282221
)
283222
I = findtruncated(diagview(S), strategy)

0 commit comments

Comments
 (0)