@@ -9,39 +9,60 @@ using MatrixAlgebraKit:
99 default_qr_algorithm,
1010 default_svd_algorithm,
1111 eig_full!,
12+ eig_full,
1213 eig_trunc!,
14+ eig_trunc,
1315 eig_vals!,
16+ eig_vals,
1417 eigh_full!,
18+ eigh_full,
1519 eigh_trunc!,
20+ eigh_trunc,
1621 eigh_vals!,
22+ eigh_vals,
1723 initialize_output,
1824 left_null!,
25+ left_null,
1926 left_orth!,
27+ left_orth,
2028 left_polar!,
29+ left_polar,
2130 lq_compact!,
31+ lq_compact,
2232 lq_full!,
33+ lq_full,
2334 qr_compact!,
35+ qr_compact,
2436 qr_full!,
37+ qr_full,
2538 right_null!,
39+ right_null,
2640 right_orth!,
41+ right_orth,
2742 right_polar!,
43+ right_polar,
2844 svd_compact!,
45+ svd_compact,
2946 svd_full!,
47+ svd_full,
3048 svd_trunc!,
49+ svd_trunc,
3150 svd_vals!,
51+ svd_vals,
3252 truncate!
3353
34- using MatrixAlgebraKit: MatrixAlgebraKit, diagview
35- # Allow customization for `Eye`.
36- _diagview (a:: AbstractMatrix ) = diagview (a)
37- function MatrixAlgebraKit. diagview (a:: KroneckerMatrix )
38- return _diagview (a. a) ⊗ _diagview (a. b)
54+ using DiagonalArrays: DiagonalArrays, diagview
55+ function DiagonalArrays. diagview (a:: KroneckerMatrix )
56+ return diagview (arg1 (a)) ⊗ diagview (arg2 (a))
3957end
58+ MatrixAlgebraKit. diagview (a:: KroneckerMatrix ) = diagview (a)
4059
4160struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
42- a :: A
43- b :: B
61+ arg1 :: A
62+ arg2 :: B
4463end
64+ arg1 (alg:: KroneckerAlgorithm ) = alg. arg1
65+ arg2 (alg:: KroneckerAlgorithm ) = alg. arg2
4566
4667using MatrixAlgebraKit:
4768 copy_input,
@@ -62,10 +83,6 @@ using MatrixAlgebraKit:
6283 svd_compact,
6384 svd_full
6485
65- function _copy_input (f:: F , a:: AbstractMatrix ) where {F}
66- return copy_input (f, a)
67- end
68-
6986for f in [
7087 :eig_full ,
7188 :eigh_full ,
@@ -80,7 +97,7 @@ for f in [
8097]
8198 @eval begin
8299 function MatrixAlgebraKit. copy_input (:: typeof ($ f), a:: KroneckerMatrix )
83- return _copy_input ($ f, a . a) ⊗ _copy_input ($ f, a . b )
100+ return copy_input ($ f, arg1 (a)) ⊗ copy_input ($ f, arg2 (a) )
84101 end
85102 end
86103end
@@ -93,105 +110,183 @@ for f in [
93110 :default_polar_algorithm ,
94111 :default_svd_algorithm ,
95112]
96- _f = Symbol (:_ , f)
97113 @eval begin
98- function $_f (A:: Type{<:AbstractMatrix} ; kwargs... )
99- return $ f (A; kwargs... )
100- end
101114 function MatrixAlgebraKit. $f (
102115 A:: Type{<:KroneckerMatrix} ; kwargs1= (;), kwargs2= (;), kwargs...
103116 )
104117 A1, A2 = argument_types (A)
105118 return KroneckerAlgorithm (
106- $ _f (A1; kwargs... , kwargs1... ), $ _f (A2; kwargs... , kwargs2... )
119+ $ f (A1; kwargs... , kwargs1... ), $ f (A2; kwargs... , kwargs2... )
107120 )
108121 end
109122 end
110123end
111124
112- # TODO : Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
113- function MatrixAlgebraKit. default_algorithm (
114- :: typeof (qr_compact!), A:: Type{<:KroneckerMatrix} ; kwargs...
115- )
116- return default_qr_algorithm (A; kwargs... )
117- end
118- # TODO : Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
119- function MatrixAlgebraKit. default_algorithm (
120- :: typeof (qr_full!), A:: Type{<:KroneckerMatrix} ; kwargs...
121- )
122- return default_qr_algorithm (A; kwargs... )
123- end
124-
125- # Allows overloading while avoiding type piracy.
126- function _initialize_output (f:: F , a:: AbstractMatrix , alg:: AbstractAlgorithm ) where {F}
127- return initialize_output (f, a, alg)
128- end
129- _initialize_output (f:: F , a:: AbstractMatrix ) where {F} = initialize_output (f, a)
130-
131125for f in [
132- :eig_full! ,
133- :eigh_full! ,
134- :qr_compact! ,
135- :qr_full! ,
136- :left_polar! ,
137- :lq_compact! ,
138- :lq_full! ,
139- :right_polar! ,
140- :svd_compact! ,
141- :svd_full! ,
126+ :eig_full ,
127+ :eigh_full ,
128+ :left_polar ,
129+ :lq_compact ,
130+ :lq_full ,
131+ :qr_compact ,
132+ :qr_full ,
133+ :right_polar ,
134+ :svd_compact ,
135+ :svd_full ,
142136]
137+ f! = Symbol (f, :! )
143138 @eval begin
144139 function MatrixAlgebraKit. initialize_output (
145- :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
140+ :: typeof ($ f! ), a, alg:: KroneckerAlgorithm
146141 )
147- return _initialize_output ( $ f, a . a, alg . a) .⊗ _initialize_output ( $ f, a . b, alg . b)
142+ return nothing
148143 end
149- function MatrixAlgebraKit. $f (
144+ function MatrixAlgebraKit. $f! (
150145 a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
151146 )
152- $ f (a . a, Base . Fix2 (getfield, :a ).(F), alg. a ; kwargs... , kwargs1... )
153- $ f (a . b, Base . Fix2 (getfield, :b ).(F), alg. b ; kwargs... , kwargs2... )
154- return F
147+ a1 = $ f (arg1 (a), arg1 ( alg) ; kwargs... , kwargs1... )
148+ a2 = $ f (arg2 (a), arg2 ( alg) ; kwargs... , kwargs2... )
149+ return a1 .⊗ a2
155150 end
156151 end
157152end
158153
159- for f in [:eig_vals! , :eigh_vals! , :svd_vals! ]
154+ for f in [
155+ :eig_vals ,
156+ :eigh_vals ,
157+ :svd_vals ,
158+ ]
159+ f! = Symbol (f, :! )
160160 @eval begin
161161 function MatrixAlgebraKit. initialize_output (
162- :: typeof ($ f), a:: KroneckerMatrix , alg:: KroneckerAlgorithm
162+ :: typeof ($ f! ), a, alg:: KroneckerAlgorithm
163163 )
164- return _initialize_output ( $ f, a . a, alg . a) ⊗ _initialize_output ( $ f, a . b, alg . b)
164+ return nothing
165165 end
166- function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm )
167- $ f (a. a, F. a, alg. a)
168- $ f (a. b, F. b, alg. b)
169- return F
166+ function MatrixAlgebraKit. $f! (
167+ a:: KroneckerMatrix , F, alg:: KroneckerAlgorithm ; kwargs1= (;), kwargs2= (;), kwargs...
168+ )
169+ a1 = $ f (arg1 (a), arg1 (alg); kwargs... , kwargs1... )
170+ a2 = $ f (arg2 (a), arg2 (alg); kwargs... , kwargs2... )
171+ return a1 ⊗ a2
170172 end
171173 end
172174end
173175
174- for f in [:left_orth! , :right_orth! ]
176+ for f in [:left_orth , :right_orth ]
177+ f! = Symbol (f, :! )
175178 @eval begin
176- function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: KroneckerMatrix )
177- return _initialize_output ($ f, a. a) .⊗ _initialize_output ($ f, a. b)
179+ function MatrixAlgebraKit. initialize_output (:: typeof ($ f!), a:: KroneckerMatrix )
180+ return nothing
181+ end
182+ function MatrixAlgebraKit. $f! (a:: KroneckerMatrix , F; kwargs1= (;), kwargs2= (;), kwargs... )
183+ a1 = $ f (arg1 (a); kwargs... , kwargs1... )
184+ a2 = $ f (arg2 (a); kwargs... , kwargs2... )
185+ return a1 .⊗ a2
178186 end
179187 end
180188end
181189
182- for f in [:left_null! , :right_null! ]
183- _f = Symbol (:_ , f )
190+ for f in [:left_null , :right_null ]
191+ f! = Symbol (f, : ! )
184192 @eval begin
185193 function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: KroneckerMatrix )
186- return _initialize_output ( $ f, a . a) ⊗ _initialize_output ( $ f, a . b)
194+ return nothing
187195 end
188- function $_f (a:: AbstractMatrix , F; kwargs... )
189- return $ f (a, F; kwargs... )
196+ function MatrixAlgebraKit. $f! (a:: KroneckerMatrix , F; kwargs1= (;), kwargs2= (;), kwargs... )
197+ a1 = $ f (arg1 (a); kwargs... , kwargs1... )
198+ a2 = $ f (arg2 (a); kwargs... , kwargs2... )
199+ return a1 ⊗ a2
190200 end
191- function MatrixAlgebraKit. $f (a:: KroneckerMatrix , F; kwargs1= (;), kwargs2= (;), kwargs... )
192- $ _f (a. a, F. a; kwargs... , kwargs1... )
193- $ _f (a. b, F. b; kwargs... , kwargs2... )
194- return F
201+ end
202+ end
203+
204+ # Truncation
205+
206+ using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate!
207+
208+ struct KroneckerTruncationStrategy{T<: TruncationStrategy } <: TruncationStrategy
209+ strategy:: T
210+ end
211+
212+ # # # Avoid instantiating the identity.
213+ # # function Base.getindex(a::EyeKronecker, I::Vararg{CartesianProduct{Colon},2})
214+ # # return a.a ⊗ a.b[I[1].b, I[2].b]
215+ # # end
216+ # # function Base.getindex(a::KroneckerEye, I::Vararg{CartesianProduct{<:Any,Colon},2})
217+ # # return a.a[I[1].a, I[2].a] ⊗ a.b
218+ # # end
219+ # # function Base.getindex(a::EyeEye, I::Vararg{CartesianProduct{Colon,Colon},2})
220+ # # return a
221+ # # end
222+
223+ # # using FillArrays: OnesVector
224+ # # const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
225+ # # const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
226+ # # const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
227+
228+ axis (a) = only (axes (a))
229+
230+ # # # Convert indices determined with a generic call to `findtruncated` to indices
231+ # # # more suited for a KroneckerVector.
232+ # # function to_truncated_indices(values::OnesKroneckerVector, I)
233+ # # prods = cartesianproduct(axis(values))[I]
234+ # # I_id = only(to_indices(arg1(values), (:,)))
235+ # # I_data = unique(arg2.(prods))
236+ # # # Drop truncations that occur within the identity.
237+ # # I_data = filter(I_data) do i
238+ # # return count(x -> arg2(x) == i, prods) == length(arg2(values))
239+ # # end
240+ # # return I_id × I_data
241+ # # end
242+ # # function to_truncated_indices(values::KroneckerOnesVector, I)
243+ # # #I = findtruncated(Vector(values), strategy.strategy)
244+ # # prods = cartesianproduct(axis(values))[I]
245+ # # I_data = unique(arg1.(prods))
246+ # # # Drop truncations that occur within the identity.
247+ # # I_data = filter(I_data) do i
248+ # # return count(x -> arg1(x) == i, prods) == length(arg2(values))
249+ # # end
250+ # # I_id = only(to_indices(arg2(values), (:,)))
251+ # # return I_data × I_id
252+ # # end
253+ function to_truncated_indices (values:: KroneckerVector , I)
254+ return throw (ArgumentError (" Not implemented" ))
255+ end
256+
257+ function MatrixAlgebraKit. findtruncated (
258+ values:: KroneckerVector , strategy:: KroneckerTruncationStrategy
259+ )
260+ I = findtruncated (Vector (values), strategy. strategy)
261+ return to_truncated_indices (values, I)
262+ end
263+
264+ for f in [:eig_trunc! , :eigh_trunc! ]
265+ @eval begin
266+ function MatrixAlgebraKit. truncate! (
267+ :: typeof ($ f), DV:: NTuple{2,KroneckerMatrix} , strategy:: TruncationStrategy
268+ )
269+ return truncate! ($ f, DV, KroneckerTruncationStrategy (strategy))
270+ end
271+ function MatrixAlgebraKit. truncate! (
272+ :: typeof ($ f), (D, V):: NTuple{2,KroneckerMatrix} , strategy:: KroneckerTruncationStrategy
273+ )
274+ I = findtruncated (diagview (D), strategy)
275+ return (D[I, I], V[(:) × (:), I])
195276 end
196277 end
197278end
279+
280+ function MatrixAlgebraKit. truncate! (
281+ f:: typeof (svd_trunc!), USVᴴ:: NTuple{3,KroneckerMatrix} , strategy:: TruncationStrategy
282+ )
283+ return truncate! (f, USVᴴ, KroneckerTruncationStrategy (strategy))
284+ end
285+ function MatrixAlgebraKit. truncate! (
286+ :: typeof (svd_trunc!),
287+ (U, S, Vᴴ):: NTuple{3,KroneckerMatrix} ,
288+ strategy:: KroneckerTruncationStrategy ,
289+ )
290+ I = findtruncated (diagview (S), strategy)
291+ return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)])
292+ end
0 commit comments