Skip to content

Commit 7b4dc77

Browse files
committed
Start rewrite
1 parent f2ec9d8 commit 7b4dc77

File tree

6 files changed

+172
-727
lines changed

6 files changed

+172
-727
lines changed

src/KroneckerArrays.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ include("linearalgebra.jl")
88
include("matrixalgebrakit.jl")
99
include("fillarrays/kroneckerarray.jl")
1010
include("fillarrays/linearalgebra.jl")
11-
include("fillarrays/matrixalgebrakit.jl")
12-
include("fillarrays/matrixalgebrakit_truncate.jl")
11+
# include("fillarrays/matrixalgebrakit.jl")
12+
# include("fillarrays/matrixalgebrakit_truncate.jl")
1313

1414
end

src/linearalgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
3030
end
3131

3232
function LinearAlgebra.diag(a::KroneckerArray)
33-
return copy(diagview(a))
33+
return copy(DiagonalArrays.diagview(a))
3434
end
3535

3636
# Allows customizing multiplication for specific types

src/matrixalgebrakit.jl

Lines changed: 166 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,60 @@ using MatrixAlgebraKit:
99
default_qr_algorithm,
1010
default_svd_algorithm,
1111
eig_full!,
12+
eig_full,
1213
eig_trunc!,
14+
eig_trunc,
1315
eig_vals!,
16+
eig_vals,
1417
eigh_full!,
18+
eigh_full,
1519
eigh_trunc!,
20+
eigh_trunc,
1621
eigh_vals!,
22+
eigh_vals,
1723
initialize_output,
1824
left_null!,
25+
left_null,
1926
left_orth!,
27+
left_orth,
2028
left_polar!,
29+
left_polar,
2130
lq_compact!,
31+
lq_compact,
2232
lq_full!,
33+
lq_full,
2334
qr_compact!,
35+
qr_compact,
2436
qr_full!,
37+
qr_full,
2538
right_null!,
39+
right_null,
2640
right_orth!,
41+
right_orth,
2742
right_polar!,
43+
right_polar,
2844
svd_compact!,
45+
svd_compact,
2946
svd_full!,
47+
svd_full,
3048
svd_trunc!,
49+
svd_trunc,
3150
svd_vals!,
51+
svd_vals,
3252
truncate!
3353

34-
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
35-
# Allow customization for `Eye`.
36-
_diagview(a::AbstractMatrix) = diagview(a)
37-
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
38-
return _diagview(a.a) _diagview(a.b)
54+
using DiagonalArrays: DiagonalArrays, diagview
55+
function DiagonalArrays.diagview(a::KroneckerMatrix)
56+
return diagview(arg1(a)) diagview(arg2(a))
3957
end
58+
MatrixAlgebraKit.diagview(a::KroneckerMatrix) = diagview(a)
4059

4160
struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
42-
a::A
43-
b::B
61+
arg1::A
62+
arg2::B
4463
end
64+
arg1(alg::KroneckerAlgorithm) = alg.arg1
65+
arg2(alg::KroneckerAlgorithm) = alg.arg2
4566

4667
using MatrixAlgebraKit:
4768
copy_input,
@@ -62,10 +83,6 @@ using MatrixAlgebraKit:
6283
svd_compact,
6384
svd_full
6485

65-
function _copy_input(f::F, a::AbstractMatrix) where {F}
66-
return copy_input(f, a)
67-
end
68-
6986
for f in [
7087
:eig_full,
7188
:eigh_full,
@@ -80,7 +97,7 @@ for f in [
8097
]
8198
@eval begin
8299
function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix)
83-
return _copy_input($f, a.a) _copy_input($f, a.b)
100+
return copy_input($f, arg1(a)) copy_input($f, arg2(a))
84101
end
85102
end
86103
end
@@ -93,105 +110,183 @@ for f in [
93110
:default_polar_algorithm,
94111
:default_svd_algorithm,
95112
]
96-
_f = Symbol(:_, f)
97113
@eval begin
98-
function $_f(A::Type{<:AbstractMatrix}; kwargs...)
99-
return $f(A; kwargs...)
100-
end
101114
function MatrixAlgebraKit.$f(
102115
A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs...
103116
)
104117
A1, A2 = argument_types(A)
105118
return KroneckerAlgorithm(
106-
$_f(A1; kwargs..., kwargs1...), $_f(A2; kwargs..., kwargs2...)
119+
$f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...)
107120
)
108121
end
109122
end
110123
end
111124

112-
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
113-
function MatrixAlgebraKit.default_algorithm(
114-
::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs...
115-
)
116-
return default_qr_algorithm(A; kwargs...)
117-
end
118-
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
119-
function MatrixAlgebraKit.default_algorithm(
120-
::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs...
121-
)
122-
return default_qr_algorithm(A; kwargs...)
123-
end
124-
125-
# Allows overloading while avoiding type piracy.
126-
function _initialize_output(f::F, a::AbstractMatrix, alg::AbstractAlgorithm) where {F}
127-
return initialize_output(f, a, alg)
128-
end
129-
_initialize_output(f::F, a::AbstractMatrix) where {F} = initialize_output(f, a)
130-
131125
for f in [
132-
:eig_full!,
133-
:eigh_full!,
134-
:qr_compact!,
135-
:qr_full!,
136-
:left_polar!,
137-
:lq_compact!,
138-
:lq_full!,
139-
:right_polar!,
140-
:svd_compact!,
141-
:svd_full!,
126+
:eig_full,
127+
:eigh_full,
128+
:left_polar,
129+
:lq_compact,
130+
:lq_full,
131+
:qr_compact,
132+
:qr_full,
133+
:right_polar,
134+
:svd_compact,
135+
:svd_full,
142136
]
137+
f! = Symbol(f, :!)
143138
@eval begin
144139
function MatrixAlgebraKit.initialize_output(
145-
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
140+
::typeof($f!), a, alg::KroneckerAlgorithm
146141
)
147-
return _initialize_output($f, a.a, alg.a) .⊗ _initialize_output($f, a.b, alg.b)
142+
return nothing
148143
end
149-
function MatrixAlgebraKit.$f(
144+
function MatrixAlgebraKit.$f!(
150145
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
151146
)
152-
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
153-
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
154-
return F
147+
a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...)
148+
a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...)
149+
return a1 .⊗ a2
155150
end
156151
end
157152
end
158153

159-
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
154+
for f in [
155+
:eig_vals,
156+
:eigh_vals,
157+
:svd_vals,
158+
]
159+
f! = Symbol(f, :!)
160160
@eval begin
161161
function MatrixAlgebraKit.initialize_output(
162-
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
162+
::typeof($f!), a, alg::KroneckerAlgorithm
163163
)
164-
return _initialize_output($f, a.a, alg.a) _initialize_output($f, a.b, alg.b)
164+
return nothing
165165
end
166-
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
167-
$f(a.a, F.a, alg.a)
168-
$f(a.b, F.b, alg.b)
169-
return F
166+
function MatrixAlgebraKit.$f!(
167+
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
168+
)
169+
a1 = $f(arg1(a), arg1(alg); kwargs..., kwargs1...)
170+
a2 = $f(arg2(a), arg2(alg); kwargs..., kwargs2...)
171+
return a1 a2
170172
end
171173
end
172174
end
173175

174-
for f in [:left_orth!, :right_orth!]
176+
for f in [:left_orth, :right_orth]
177+
f! = Symbol(f, :!)
175178
@eval begin
176-
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
177-
return _initialize_output($f, a.a) .⊗ _initialize_output($f, a.b)
179+
function MatrixAlgebraKit.initialize_output(::typeof($f!), a::KroneckerMatrix)
180+
return nothing
181+
end
182+
function MatrixAlgebraKit.$f!(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...)
183+
a1 = $f(arg1(a); kwargs..., kwargs1...)
184+
a2 = $f(arg2(a); kwargs..., kwargs2...)
185+
return a1 .⊗ a2
178186
end
179187
end
180188
end
181189

182-
for f in [:left_null!, :right_null!]
183-
_f = Symbol(:_, f)
190+
for f in [:left_null, :right_null]
191+
f! = Symbol(f, :!)
184192
@eval begin
185193
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
186-
return _initialize_output($f, a.a) _initialize_output($f, a.b)
194+
return nothing
187195
end
188-
function $_f(a::AbstractMatrix, F; kwargs...)
189-
return $f(a, F; kwargs...)
196+
function MatrixAlgebraKit.$f!(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...)
197+
a1 = $f(arg1(a); kwargs..., kwargs1...)
198+
a2 = $f(arg2(a); kwargs..., kwargs2...)
199+
return a1 a2
190200
end
191-
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...)
192-
$_f(a.a, F.a; kwargs..., kwargs1...)
193-
$_f(a.b, F.b; kwargs..., kwargs2...)
194-
return F
201+
end
202+
end
203+
204+
# Truncation
205+
206+
using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate!
207+
208+
struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
209+
strategy::T
210+
end
211+
212+
## # Avoid instantiating the identity.
213+
## function Base.getindex(a::EyeKronecker, I::Vararg{CartesianProduct{Colon},2})
214+
## return a.a ⊗ a.b[I[1].b, I[2].b]
215+
## end
216+
## function Base.getindex(a::KroneckerEye, I::Vararg{CartesianProduct{<:Any,Colon},2})
217+
## return a.a[I[1].a, I[2].a] ⊗ a.b
218+
## end
219+
## function Base.getindex(a::EyeEye, I::Vararg{CartesianProduct{Colon,Colon},2})
220+
## return a
221+
## end
222+
223+
## using FillArrays: OnesVector
224+
## const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
225+
## const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
226+
## const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
227+
228+
axis(a) = only(axes(a))
229+
230+
## # Convert indices determined with a generic call to `findtruncated` to indices
231+
## # more suited for a KroneckerVector.
232+
## function to_truncated_indices(values::OnesKroneckerVector, I)
233+
## prods = cartesianproduct(axis(values))[I]
234+
## I_id = only(to_indices(arg1(values), (:,)))
235+
## I_data = unique(arg2.(prods))
236+
## # Drop truncations that occur within the identity.
237+
## I_data = filter(I_data) do i
238+
## return count(x -> arg2(x) == i, prods) == length(arg2(values))
239+
## end
240+
## return I_id × I_data
241+
## end
242+
## function to_truncated_indices(values::KroneckerOnesVector, I)
243+
## #I = findtruncated(Vector(values), strategy.strategy)
244+
## prods = cartesianproduct(axis(values))[I]
245+
## I_data = unique(arg1.(prods))
246+
## # Drop truncations that occur within the identity.
247+
## I_data = filter(I_data) do i
248+
## return count(x -> arg1(x) == i, prods) == length(arg2(values))
249+
## end
250+
## I_id = only(to_indices(arg2(values), (:,)))
251+
## return I_data × I_id
252+
## end
253+
function to_truncated_indices(values::KroneckerVector, I)
254+
return throw(ArgumentError("Not implemented"))
255+
end
256+
257+
function MatrixAlgebraKit.findtruncated(
258+
values::KroneckerVector, strategy::KroneckerTruncationStrategy
259+
)
260+
I = findtruncated(Vector(values), strategy.strategy)
261+
return to_truncated_indices(values, I)
262+
end
263+
264+
for f in [:eig_trunc!, :eigh_trunc!]
265+
@eval begin
266+
function MatrixAlgebraKit.truncate!(
267+
::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy
268+
)
269+
return truncate!($f, DV, KroneckerTruncationStrategy(strategy))
270+
end
271+
function MatrixAlgebraKit.truncate!(
272+
::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy
273+
)
274+
I = findtruncated(diagview(D), strategy)
275+
return (D[I, I], V[(:) × (:), I])
195276
end
196277
end
197278
end
279+
280+
function MatrixAlgebraKit.truncate!(
281+
f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy
282+
)
283+
return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy))
284+
end
285+
function MatrixAlgebraKit.truncate!(
286+
::typeof(svd_trunc!),
287+
(U, S, Vᴴ)::NTuple{3,KroneckerMatrix},
288+
strategy::KroneckerTruncationStrategy,
289+
)
290+
I = findtruncated(diagview(S), strategy)
291+
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)])
292+
end

0 commit comments

Comments
 (0)