Skip to content

Commit 75fff94

Browse files
authored
Improve overloading of orthnull functions (#59)
1 parent 03f2255 commit 75fff94

File tree

2 files changed

+59
-87
lines changed

2 files changed

+59
-87
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
authors = ["ITensor developers <[email protected]> and contributors"]
55

66
[deps]

src/matrixalgebrakit.jl

Lines changed: 58 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ using MatrixAlgebraKit:
2020
using MatrixAlgebraKit: TruncationStrategy, findtruncated, truncate
2121
import MatrixAlgebraKit as MAK
2222

23-
DiagonalArrays.diagview(a::AbstractKroneckerMatrix) = (DiagonalArrays.diagview.(kroneckerfactors(a))...)
23+
function DiagonalArrays.diagview(a::AbstractKroneckerMatrix)
24+
return (DiagonalArrays.diagview.(kroneckerfactors(a))...)
25+
end
2426
MatrixAlgebraKit.diagview(a::AbstractKroneckerMatrix) = DiagonalArrays.diagview(a)
2527

2628
struct KroneckerAlgorithm{A, B} <: AbstractAlgorithm
@@ -51,8 +53,10 @@ for f in (
5153
:default_lq_algorithm, :default_qr_algorithm,
5254
:default_polar_algorithm, :default_svd_algorithm,
5355
)
54-
@eval function MAK.$f(A::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...)
55-
A, B = kroneckerfactortypes(A)
56+
@eval function MAK.$f(
57+
AB::Type{<:AbstractKroneckerMatrix}; kwargs1 = (;), kwargs2 = (;), kwargs...
58+
)
59+
A, B = kroneckerfactortypes(AB)
5660
return KroneckerAlgorithm(
5761
MAK.$f(A; kwargs..., kwargs1...),
5862
MAK.$f(B; kwargs..., kwargs2...)
@@ -68,7 +72,11 @@ for f in (
6872
:svd_compact, :svd_full,
6973
)
7074
f! = Symbol(f, :!)
71-
@eval MAK.initialize_output(::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm) = nothing
75+
@eval function MAK.initialize_output(
76+
::typeof($f!), a::AbstractMatrix, alg::KroneckerAlgorithm
77+
)
78+
return nothing
79+
end
7280
@eval function MAK.$f!(ab::AbstractKroneckerMatrix, F, alg::KroneckerAlgorithm)
7381
a, b = kroneckerfactors(ab)
7482
algA, algB = kroneckerfactors(alg)
@@ -92,111 +100,66 @@ for f in (:eig_vals, :eigh_vals, :svd_vals)
92100
end
93101
end
94102

95-
# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
96-
# is merged.
97-
for kind in ("polar", "qr", "svd")
103+
for f! in (:left_orth!, :right_orth!, :left_null!, :right_null!)
98104
@eval begin
99-
function MAK.initialize_output(
100-
::typeof(left_orth!), a::AbstractKroneckerMatrix,
101-
alg::MAK.LeftOrthAlgorithm{Symbol($kind)},
105+
function MAK.default_algorithm(
106+
::typeof($f!), AB::Type{<:AbstractKroneckerMatrix}; kwargs...
102107
)
103-
return nothing
108+
A, B = kroneckerfactortypes(AB)
109+
algA = MAK.default_algorithm($f!, A; kwargs...)
110+
algB = MAK.default_algorithm($f!, B; kwargs...)
111+
return KroneckerAlgorithm(algA, algB)
104112
end
105-
function MAK.left_orth!(
106-
ab::AbstractKroneckerMatrix, F, alg::MAK.LeftOrthAlgorithm{Symbol($kind)};
107-
kwargs1 = (;), kwargs2 = (;), kwargs...,
113+
function MAK.select_algorithm(
114+
::typeof($f!), ab::AbstractKroneckerMatrix, alg::Symbol; kwargs...
108115
)
109116
a, b = kroneckerfactors(ab)
110-
Fa = MAK.left_orth!(a; kwargs..., kwargs1...)
111-
Fb = MAK.left_orth!(b; kwargs..., kwargs2...)
112-
return Fa .⊗ Fb
117+
algA = MAK.select_algorithm($f!, a, alg; kwargs...)
118+
algB = MAK.select_algorithm($f!, b, alg; kwargs...)
119+
return KroneckerAlgorithm(algA, algB)
113120
end
121+
MAK.initialize_output(::typeof($f!), A, alg::KroneckerAlgorithm) = nothing
114122
end
115123
end
116-
117-
# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
118-
# is merged.
119-
for kind in ("lq", "polar", "svd")
124+
for f! in (:left_orth!, :right_orth!)
120125
@eval begin
121-
function MAK.initialize_output(
122-
::typeof(right_orth!), a::AbstractKroneckerMatrix,
123-
alg::MAK.RightOrthAlgorithm{Symbol($kind)},
124-
)
125-
return nothing
126-
end
127-
function MAK.right_orth!(
128-
ab::AbstractKroneckerMatrix, F, alg::MAK.RightOrthAlgorithm{Symbol($kind)};
129-
kwargs1 = (;), kwargs2 = (;), kwargs...,
126+
function MAK.$f!(
127+
ab, F, alg::KroneckerAlgorithm; kwargs1 = (;), kwargs2 = (;), kwargs...,
130128
)
131129
a, b = kroneckerfactors(ab)
132-
Fa = MAK.right_orth!(a; kwargs..., kwargs1...)
133-
Fb = MAK.right_orth!(b; kwargs..., kwargs2...)
130+
Fa = MAK.$f!(a; kwargs..., kwargs1...)
131+
Fb = MAK.$f!(b; kwargs..., kwargs2...)
134132
return Fa .⊗ Fb
135133
end
136134
end
137135
end
138-
139-
# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
140-
# is merged.
141-
for Alg in (
142-
:(MAK.LeftNullViaQR),
143-
:(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm}),
144-
:(MAK.LeftNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}),
145-
)
146-
@eval begin
147-
function MAK.initialize_output(
148-
::typeof(left_null!), a::AbstractKroneckerMatrix, alg::$Alg
149-
)
150-
return nothing
151-
end
152-
function MAK.left_null!(
153-
ab::AbstractKroneckerMatrix, F, alg::$Alg;
154-
kwargs1 = (;), kwargs2 = (;), kwargs...,
155-
)
156-
a, b = kroneckerfactors(ab)
157-
Na = MAK.left_null!(a; kwargs..., kwargs1...)
158-
Nb = MAK.left_null!(b; kwargs..., kwargs2...)
159-
return Na Nb
160-
end
161-
end
162-
end
163-
164-
# TODO: Delete this loop once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/104
165-
# is merged.
166-
for Alg in (
167-
:(MAK.RightNullViaLQ),
168-
:(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm}),
169-
:(MAK.RightNullViaSVD{<:MAK.TruncatedAlgorithm{<:MAK.GPU_Randomized}}),
170-
)
136+
for f! in (:left_null!, :right_null!)
171137
@eval begin
172-
function MAK.initialize_output(
173-
::typeof(right_null!), a::AbstractKroneckerMatrix, alg::$Alg
174-
)
175-
return nothing
176-
end
177-
function MAK.right_null!(
178-
ab::AbstractKroneckerMatrix, F, alg::$Alg;
138+
function MAK.$f!(
139+
ab, F, alg::KroneckerAlgorithm;
179140
kwargs1 = (;), kwargs2 = (;), kwargs...,
180141
)
181142
a, b = kroneckerfactors(ab)
182-
Na = MAK.right_null!(a; kwargs..., kwargs1...)
183-
Nb = MAK.right_null!(b; kwargs..., kwargs2...)
184-
return Na Nb
143+
Fa = MAK.$f!(a; kwargs..., kwargs1...)
144+
Fb = MAK.$f!(b; kwargs..., kwargs2...)
145+
return Fa Fb
185146
end
186147
end
187148
end
188149

189150
# Truncation
190151

191-
192152
struct KroneckerTruncationStrategy{T <: TruncationStrategy} <: TruncationStrategy
193153
strategy::T
194154
end
195155

196156
using FillArrays: OnesVector
197-
const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} = KroneckerVector{T, A, B}
198-
const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B}
199-
const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} = KroneckerVector{T, A, B}
157+
const OnesKroneckerVector{T, A <: OnesVector{T}, B <: AbstractVector{T}} =
158+
KroneckerVector{T, A, B}
159+
const KroneckerOnesVector{T, A <: AbstractVector{T}, B <: OnesVector{T}} =
160+
KroneckerVector{T, A, B}
161+
const OnesVectorOnesVector{T, A <: OnesVector{T}, B <: OnesVector{T}} =
162+
KroneckerVector{T, A, B}
200163

201164
axis(a) = only(axes(a))
202165

@@ -208,7 +171,8 @@ function to_truncated_indices(values::OnesKroneckerVector, I)
208171
I_data = unique(kroneckerfactors.(prods, 2))
209172
# Drop truncations that occur within the identity.
210173
I_data = filter(I_data) do i
211-
return count(x -> kroneckerfactors(x, 2) == i, prods) == length(kroneckerfactors(values, 2))
174+
return count(x -> kroneckerfactors(x, 2) == i, prods) ==
175+
length(kroneckerfactors(values, 2))
212176
end
213177
return I_id × I_data
214178
end
@@ -218,7 +182,8 @@ function to_truncated_indices(values::KroneckerOnesVector, I)
218182
I_data = unique(kroneckerfactors.(prods, 1))
219183
# Drop truncations that occur within the identity.
220184
I_data = filter(I_data) do i
221-
return count(x -> kroneckerfactors(x, 1) == i, prods) == length(kroneckerfactors(values, 2))
185+
return count(x -> kroneckerfactors(x, 1) == i, prods) ==
186+
length(kroneckerfactors(values, 2))
222187
end
223188
I_id = only(to_indices(kroneckerfactors(values, 2), (:,)))
224189
return I_data × I_id
@@ -240,22 +205,29 @@ end
240205

241206
for f in (:eig_trunc!, :eigh_trunc!)
242207
@eval function MAK.truncate(
243-
::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix}, strategy::TruncationStrategy
208+
::typeof($f), DV::NTuple{2, AbstractKroneckerMatrix},
209+
strategy::TruncationStrategy,
244210
)
245211
return MAK.truncate($f, DV, KroneckerTruncationStrategy(strategy))
246212
end
247213
@eval function MAK.truncate(
248-
::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy
214+
::typeof($f), (D, V)::NTuple{2, AbstractKroneckerMatrix},
215+
strategy::KroneckerTruncationStrategy,
249216
)
250217
I = MAK.findtruncated(MAK.diagview(D), strategy)
251218
return (D[I, I], V[(:) × (:), I]), I
252219
end
253220
end
254221

255-
MAK.truncate(f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix}, strategy::TruncationStrategy) =
256-
MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy))
257222
function MAK.truncate(
258-
::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix}, strategy::KroneckerTruncationStrategy,
223+
f::typeof(svd_trunc!), USVᴴ::NTuple{3, AbstractKroneckerMatrix},
224+
strategy::TruncationStrategy,
225+
)
226+
return MAK.truncate(f, USVᴴ, KroneckerTruncationStrategy(strategy))
227+
end
228+
function MAK.truncate(
229+
::typeof(svd_trunc!), (U, S, Vᴴ)::NTuple{3, AbstractKroneckerMatrix},
230+
strategy::KroneckerTruncationStrategy,
259231
)
260232
I = MAK.findtruncated(MAK.diagview(S), strategy)
261233
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]), I

0 commit comments

Comments
 (0)