@@ -20,7 +20,9 @@ using MatrixAlgebraKit:
2020using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate
2121import MatrixAlgebraKit as MAK
2222
23- DiagonalArrays. diagview (a:: AbstractKroneckerMatrix ) = ⊗ (DiagonalArrays. diagview .(kroneckerfactors (a))... )
23+ function DiagonalArrays. diagview (a:: AbstractKroneckerMatrix )
24+ return ⊗ (DiagonalArrays. diagview .(kroneckerfactors (a))... )
25+ end
2426MatrixAlgebraKit. diagview (a:: AbstractKroneckerMatrix ) = DiagonalArrays. diagview (a)
2527
2628struct KroneckerAlgorithm{A, B} <: AbstractAlgorithm
@@ -51,8 +53,10 @@ for f in (
5153 :default_lq_algorithm , :default_qr_algorithm ,
5254 :default_polar_algorithm , :default_svd_algorithm ,
5355 )
54- @eval function MAK. $f (A:: Type{<:AbstractKroneckerMatrix} ; kwargs1 = (;), kwargs2 = (;), kwargs... )
55- A, B = kroneckerfactortypes (A)
56+ @eval function MAK. $f (
57+ AB:: Type{<:AbstractKroneckerMatrix} ; kwargs1 = (;), kwargs2 = (;), kwargs...
58+ )
59+ A, B = kroneckerfactortypes (AB)
5660 return KroneckerAlgorithm (
5761 MAK.$ f (A; kwargs... , kwargs1... ),
5862 MAK.$ f (B; kwargs... , kwargs2... )
@@ -68,7 +72,11 @@ for f in (
6872 :svd_compact , :svd_full ,
6973 )
7074 f! = Symbol (f, :! )
71- @eval MAK. initialize_output (:: typeof ($ f!), a:: AbstractMatrix , alg:: KroneckerAlgorithm ) = nothing
75+ @eval function MAK. initialize_output (
76+ :: typeof ($ f!), a:: AbstractMatrix , alg:: KroneckerAlgorithm
77+ )
78+ return nothing
79+ end
7280 @eval function MAK. $f! (ab:: AbstractKroneckerMatrix , F, alg:: KroneckerAlgorithm )
7381 a, b = kroneckerfactors (ab)
7482 algA, algB = kroneckerfactors (alg)
@@ -92,111 +100,66 @@ for f in (:eig_vals, :eigh_vals, :svd_vals)
92100 end
93101end
94102
95- # TODO : Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
96- # is merged.
97- for kind in (" polar" , " qr" , " svd" )
103+ for f! in (:left_orth! , :right_orth! , :left_null! , :right_null! )
98104 @eval begin
99- function MAK. initialize_output (
100- :: typeof (left_orth!), a:: AbstractKroneckerMatrix ,
101- alg:: MAK.LeftOrthAlgorithm{Symbol($kind)} ,
105+ function MAK. default_algorithm (
106+ :: typeof ($ f!), AB:: Type{<:AbstractKroneckerMatrix} ; kwargs...
102107 )
103- return nothing
108+ A, B = kroneckerfactortypes (AB)
109+ algA = MAK. default_algorithm ($ f!, A; kwargs... )
110+ algB = MAK. default_algorithm ($ f!, B; kwargs... )
111+ return KroneckerAlgorithm (algA, algB)
104112 end
105- function MAK. left_orth! (
106- ab:: AbstractKroneckerMatrix , F, alg:: MAK.LeftOrthAlgorithm{Symbol($kind)} ;
107- kwargs1 = (;), kwargs2 = (;), kwargs... ,
113+ function MAK. select_algorithm (
114+ :: typeof ($ f!), ab:: AbstractKroneckerMatrix , alg:: Symbol ; kwargs...
108115 )
109116 a, b = kroneckerfactors (ab)
110- Fa = MAK. left_orth! (a ; kwargs... , kwargs1 ... )
111- Fb = MAK. left_orth! (b ; kwargs... , kwargs2 ... )
112- return Fa .⊗ Fb
117+ algA = MAK. select_algorithm ( $ f!, a, alg ; kwargs... )
118+ algB = MAK. select_algorithm ( $ f!, b, alg ; kwargs... )
119+ return KroneckerAlgorithm (algA, algB)
113120 end
121+ MAK. initialize_output (:: typeof ($ f!), A, alg:: KroneckerAlgorithm ) = nothing
114122 end
115123end
116-
117- # TODO : Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
118- # is merged.
119- for kind in (" lq" , " polar" , " svd" )
124+ for f! in (:left_orth! , :right_orth! )
120125 @eval begin
121- function MAK. initialize_output (
122- :: typeof (right_orth!), a:: AbstractKroneckerMatrix ,
123- alg:: MAK.RightOrthAlgorithm{Symbol($kind)} ,
124- )
125- return nothing
126- end
127- function MAK. right_orth! (
128- ab:: AbstractKroneckerMatrix , F, alg:: MAK.RightOrthAlgorithm{Symbol($kind)} ;
129- kwargs1 = (;), kwargs2 = (;), kwargs... ,
126+ function MAK. $f! (
127+ ab, F, alg:: KroneckerAlgorithm ; kwargs1 = (;), kwargs2 = (;), kwargs... ,
130128 )
131129 a, b = kroneckerfactors (ab)
132- Fa = MAK. right_orth ! (a; kwargs... , kwargs1... )
133- Fb = MAK. right_orth ! (b; kwargs... , kwargs2... )
130+ Fa = MAK.$ f ! (a; kwargs... , kwargs1... )
131+ Fb = MAK.$ f ! (b; kwargs... , kwargs2... )
134132 return Fa .⊗ Fb
135133 end
136134 end
137135end
138-
139- # TODO : Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
140- # is merged.
141- for Alg in (
142- :(MAK. LeftNullViaQR),
143- :(MAK. LeftNullViaSVD{<: MAK.TruncatedAlgorithm }),
144- :(MAK. LeftNullViaSVD{<: MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized} }),
145- )
146- @eval begin
147- function MAK. initialize_output (
148- :: typeof (left_null!), a:: AbstractKroneckerMatrix , alg:: $Alg
149- )
150- return nothing
151- end
152- function MAK. left_null! (
153- ab:: AbstractKroneckerMatrix , F, alg:: $Alg ;
154- kwargs1 = (;), kwargs2 = (;), kwargs... ,
155- )
156- a, b = kroneckerfactors (ab)
157- Na = MAK. left_null! (a; kwargs... , kwargs1... )
158- Nb = MAK. left_null! (b; kwargs... , kwargs2... )
159- return Na ⊗ Nb
160- end
161- end
162- end
163-
164- # TODO : Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
165- # is merged.
166- for Alg in (
167- :(MAK. RightNullViaLQ),
168- :(MAK. RightNullViaSVD{<: MAK.TruncatedAlgorithm }),
169- :(MAK. RightNullViaSVD{<: MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized} }),
170- )
136+ for f! in (:left_null! , :right_null! )
171137 @eval begin
172- function MAK. initialize_output (
173- :: typeof (right_null!), a:: AbstractKroneckerMatrix , alg:: $Alg
174- )
175- return nothing
176- end
177- function MAK. right_null! (
178- ab:: AbstractKroneckerMatrix , F, alg:: $Alg ;
138+ function MAK. $f! (
139+ ab, F, alg:: KroneckerAlgorithm ;
179140 kwargs1 = (;), kwargs2 = (;), kwargs... ,
180141 )
181142 a, b = kroneckerfactors (ab)
182- Na = MAK. right_null ! (a; kwargs... , kwargs1... )
183- Nb = MAK. right_null ! (b; kwargs... , kwargs2... )
184- return Na ⊗ Nb
143+ Fa = MAK.$ f ! (a; kwargs... , kwargs1... )
144+ Fb = MAK.$ f ! (b; kwargs... , kwargs2... )
145+ return Fa ⊗ Fb
185146 end
186147 end
187148end
188149
189150# Truncation
190151
191-
192152struct KroneckerTruncationStrategy{T <: TruncationStrategy } <: TruncationStrategy
193153 strategy:: T
194154end
195155
196156using FillArrays: OnesVector
197- const OnesKroneckerVector{T, A <: OnesVector{T} , B <: AbstractVector{T} } = KroneckerVector{T, A, B}
198- const KroneckerOnesVector{T, A <: AbstractVector{T} , B <: OnesVector{T} } = KroneckerVector{T, A, B}
199- const OnesVectorOnesVector{T, A <: OnesVector{T} , B <: OnesVector{T} } = KroneckerVector{T, A, B}
157+ const OnesKroneckerVector{T, A <: OnesVector{T} , B <: AbstractVector{T} } =
158+ KroneckerVector{T, A, B}
159+ const KroneckerOnesVector{T, A <: AbstractVector{T} , B <: OnesVector{T} } =
160+ KroneckerVector{T, A, B}
161+ const OnesVectorOnesVector{T, A <: OnesVector{T} , B <: OnesVector{T} } =
162+ KroneckerVector{T, A, B}
200163
201164axis (a) = only (axes (a))
202165
@@ -208,7 +171,8 @@ function to_truncated_indices(values::OnesKroneckerVector, I)
208171 I_data = unique (kroneckerfactors .(prods, 2 ))
209172 # Drop truncations that occur within the identity.
210173 I_data = filter (I_data) do i
211- return count (x -> kroneckerfactors (x, 2 ) == i, prods) == length (kroneckerfactors (values, 2 ))
174+ return count (x -> kroneckerfactors (x, 2 ) == i, prods) ==
175+ length (kroneckerfactors (values, 2 ))
212176 end
213177 return I_id × I_data
214178end
@@ -218,7 +182,8 @@ function to_truncated_indices(values::KroneckerOnesVector, I)
218182 I_data = unique (kroneckerfactors .(prods, 1 ))
219183 # Drop truncations that occur within the identity.
220184 I_data = filter (I_data) do i
221- return count (x -> kroneckerfactors (x, 1 ) == i, prods) == length (kroneckerfactors (values, 2 ))
185+ return count (x -> kroneckerfactors (x, 1 ) == i, prods) ==
186+ length (kroneckerfactors (values, 2 ))
222187 end
223188 I_id = only (to_indices (kroneckerfactors (values, 2 ), (:,)))
224189 return I_data × I_id
@@ -240,22 +205,29 @@ end
240205
241206for f in (:eig_trunc! , :eigh_trunc! )
242207 @eval function MAK. truncate (
243- :: typeof ($ f), DV:: NTuple{2, AbstractKroneckerMatrix} , strategy:: TruncationStrategy
208+ :: typeof ($ f), DV:: NTuple{2, AbstractKroneckerMatrix} ,
209+ strategy:: TruncationStrategy ,
244210 )
245211 return MAK. truncate ($ f, DV, KroneckerTruncationStrategy (strategy))
246212 end
247213 @eval function MAK. truncate (
248- :: typeof ($ f), (D, V):: NTuple{2, AbstractKroneckerMatrix} , strategy:: KroneckerTruncationStrategy
214+ :: typeof ($ f), (D, V):: NTuple{2, AbstractKroneckerMatrix} ,
215+ strategy:: KroneckerTruncationStrategy ,
249216 )
250217 I = MAK. findtruncated (MAK. diagview (D), strategy)
251218 return (D[I, I], V[(:) × (:), I]), I
252219 end
253220end
254221
255- MAK. truncate (f:: typeof (svd_trunc!), USVᴴ:: NTuple{3, AbstractKroneckerMatrix} , strategy:: TruncationStrategy ) =
256- MAK. truncate (f, USVᴴ, KroneckerTruncationStrategy (strategy))
257222function MAK. truncate (
258- :: typeof (svd_trunc!), (U, S, Vᴴ):: NTuple{3, AbstractKroneckerMatrix} , strategy:: KroneckerTruncationStrategy ,
223+ f:: typeof (svd_trunc!), USVᴴ:: NTuple{3, AbstractKroneckerMatrix} ,
224+ strategy:: TruncationStrategy ,
225+ )
226+ return MAK. truncate (f, USVᴴ, KroneckerTruncationStrategy (strategy))
227+ end
228+ function MAK. truncate (
229+ :: typeof (svd_trunc!), (U, S, Vᴴ):: NTuple{3, AbstractKroneckerMatrix} ,
230+ strategy:: KroneckerTruncationStrategy ,
259231 )
260232 I = MAK. findtruncated (MAK. diagview (S), strategy)
261233 return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]), I
0 commit comments