Skip to content

Commit 35233cb

Browse files
committed
Add support for truncation
1 parent 6af509f commit 35233cb

File tree

4 files changed

+298
-23
lines changed

4 files changed

+298
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

src/KroneckerArrays.jl

Lines changed: 176 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ end
1111
arguments(a::CartesianProduct) = (a.a, a.b)
1212
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
1313

14+
function Base.show(io::IO, a::CartesianProduct)
15+
print(io, a.a, " × ", a.b)
16+
return nothing
17+
end
18+
1419
×(a, b) = CartesianProduct(a, b)
1520
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
1621
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
@@ -130,6 +135,8 @@ function interleave(x::Tuple, y::Tuple)
130135
xy = ntuple(i -> (x[i], y[i]), length(x))
131136
return flatten(xy)
132137
end
138+
# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing:
139+
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
133140
function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
134141
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
135142
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
@@ -183,6 +190,9 @@ function Base.getindex(a::KroneckerArray, i::Integer)
183190
return a[CartesianIndices(a)[i]]
184191
end
185192

193+
# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing
194+
# in the n-dimensional case and use it to replace the matrix and vector cases:
195+
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
186196
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
187197
return error("Not implemented.")
188198
end
@@ -222,6 +232,10 @@ end
222232
function Base.inv(a::KroneckerArray)
223233
return inv(a.a) inv(a.b)
224234
end
235+
using LinearAlgebra: LinearAlgebra, pinv
236+
function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
237+
return pinv(a.a; kwargs...) pinv(a.b; kwargs...)
238+
end
225239
function Base.transpose(a::KroneckerArray)
226240
return transpose(a.a) transpose(a.b)
227241
end
@@ -297,6 +311,7 @@ using LinearAlgebra:
297311
Diagonal,
298312
Eigen,
299313
SVD,
314+
det,
300315
diag,
301316
eigen,
302317
eigvals,
@@ -335,9 +350,63 @@ end
335350
function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
336351
return norm(a.a, p) norm(a.b, p)
337352
end
353+
354+
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
355+
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
356+
return diagview(a.a) diagview(a.b)
357+
end
338358
function LinearAlgebra.diag(a::KroneckerArray)
339-
return diag(a.a) diag(a.b)
359+
return copy(diagview(a.a)) copy(diagview(a.b))
360+
end
361+
362+
# Matrix functions
363+
const MATRIX_FUNCTIONS = [
364+
:exp,
365+
:cis,
366+
:log,
367+
:sqrt,
368+
:cbrt,
369+
:cos,
370+
:sin,
371+
:tan,
372+
:csc,
373+
:sec,
374+
:cot,
375+
:cosh,
376+
:sinh,
377+
:tanh,
378+
:csch,
379+
:sech,
380+
:coth,
381+
:acos,
382+
:asin,
383+
:atan,
384+
:acsc,
385+
:asec,
386+
:acot,
387+
:acosh,
388+
:asinh,
389+
:atanh,
390+
:acsch,
391+
:asech,
392+
:acoth,
393+
]
394+
395+
for f in MATRIX_FUNCTIONS
396+
@eval begin
397+
function Base.$f(a::KroneckerArray)
398+
return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported."))
399+
end
400+
end
401+
end
402+
403+
using LinearAlgebra: checksquare
404+
function LinearAlgebra.det(a::KroneckerArray)
405+
checksquare(a.a)
406+
checksquare(a.b)
407+
return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1)
340408
end
409+
341410
function LinearAlgebra.svd(a::KroneckerArray)
342411
Fa = svd(a.a)
343412
Fb = svd(a.b)
@@ -690,18 +759,6 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
690759
end
691760
end
692761

693-
for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!]
694-
@eval begin
695-
function MatrixAlgebraKit.truncate!(
696-
::typeof($f),
697-
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
698-
strategy::TruncationStrategy,
699-
)
700-
return throw(MethodError(truncate!, ($f, (D, V), strategy)))
701-
end
702-
end
703-
end
704-
705762
for f in [:left_orth!, :right_orth!]
706763
@eval begin
707764
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
@@ -941,4 +998,110 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
941998
end
942999
end
9431000

1001+
using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate!
1002+
1003+
struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
1004+
strategy::T
1005+
end
1006+
1007+
# Avoid instantiating the identity.
1008+
function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2})
1009+
return a.a a.b[I[1].b, I[2].b]
1010+
end
1011+
function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2})
1012+
return a.a[I[1].a, I[2].a] a.b
1013+
end
1014+
function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2})
1015+
return a
1016+
end
1017+
1018+
using FillArrays: OnesVector
1019+
const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
1020+
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
1021+
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
1022+
1023+
function MatrixAlgebraKit.findtruncated(
1024+
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
1025+
)
1026+
I = findtruncated(Vector(values), strategy.strategy)
1027+
prods = collect(only(axes(values)).product)[I]
1028+
I_data = unique(map(x -> x.a, prods))
1029+
# Drop truncations that occur within the identity.
1030+
I_data = filter(I_data) do i
1031+
return count(x -> x.a == i, prods) == length(values.a)
1032+
end
1033+
return (:) × I_data
1034+
end
1035+
function MatrixAlgebraKit.findtruncated(
1036+
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
1037+
)
1038+
I = findtruncated(Vector(values), strategy.strategy)
1039+
prods = collect(only(axes(values)).product)[I]
1040+
I_data = unique(map(x -> x.b, prods))
1041+
# Drop truncations that occur within the identity.
1042+
I_data = filter(I_data) do i
1043+
return count(x -> x.b == i, prods) == length(values.b)
1044+
end
1045+
return I_data × (:)
1046+
end
1047+
function MatrixAlgebraKit.findtruncated(
1048+
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
1049+
)
1050+
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
1051+
end
1052+
1053+
for f in [:eig_trunc!, :eigh_trunc!]
1054+
@eval begin
1055+
function MatrixAlgebraKit.truncate!(
1056+
::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy
1057+
)
1058+
return truncate!($f, DV, KroneckerTruncationStrategy(strategy))
1059+
end
1060+
function MatrixAlgebraKit.truncate!(
1061+
::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy
1062+
)
1063+
I = findtruncated(diagview(D), strategy)
1064+
return (D[I, I], V[(:) × (:), I])
1065+
end
1066+
end
1067+
end
1068+
1069+
function MatrixAlgebraKit.truncate!(
1070+
f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy
1071+
)
1072+
return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy))
1073+
end
1074+
function MatrixAlgebraKit.truncate!(
1075+
::typeof(svd_trunc!),
1076+
(U, S, Vᴴ)::NTuple{3,KroneckerMatrix},
1077+
strategy::KroneckerTruncationStrategy,
1078+
)
1079+
I = findtruncated(diagview(S), strategy)
1080+
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)])
1081+
end
1082+
1083+
for f in MATRIX_FUNCTIONS
1084+
@eval begin
1085+
function Base.$f(a::SquareEyeKronecker)
1086+
return a.a $f(a.b)
1087+
end
1088+
function Base.$f(a::KroneckerSquareEye)
1089+
return $f(a.a) a.b
1090+
end
1091+
function Base.$f(a::SquareEyeSquareEye)
1092+
return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported."))
1093+
end
1094+
end
1095+
end
1096+
1097+
function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...)
1098+
return a.a pinv(a.b; kwargs...)
1099+
end
1100+
function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...)
1101+
return pinv(a.a; kwargs...) a.b
1102+
end
1103+
function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...)
1104+
return a
1105+
end
1106+
9441107
end

test/test_basics.jl

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using FillArrays: Eye
22
using KroneckerArrays: KroneckerArrays, , ×, diagonal, kron_nd
3-
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
4-
using Test: @test, @testset
3+
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr
4+
using Test: @test, @test_broken, @test_throws, @testset
55

66
const elts = (Float32, Float64, ComplexF32, ComplexF64)
77
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
@@ -35,7 +35,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
3535
@test iszero(a - a)
3636
@test collect(a + c) collect(a) + collect(c)
3737
@test collect(b + c) collect(b) + collect(c)
38-
for f in (transpose, adjoint, inv)
38+
for f in (transpose, adjoint, inv, pinv)
3939
@test collect(f(a)) f(collect(a))
4040
end
4141
@test tr(a) tr(collect(a))
@@ -66,6 +66,16 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
6666
Q, R = qr(a)
6767
@test collect(Q * R) collect(a)
6868
@test collect(Q'Q) I
69+
70+
a = randn(elt, 2, 2) randn(elt, 3, 3)
71+
@test det(a) det(collect(a))
72+
73+
a = randn(elt, 2, 2) randn(elt, 3, 3)
74+
for f in KroneckerArrays.MATRIX_FUNCTIONS
75+
@eval begin
76+
@test_throws ArgumentError $f($a)
77+
end
78+
end
6979
end
7080

7181
@testset "FillArrays.Eye" begin
@@ -80,4 +90,64 @@ end
8090
@test a + a == (2a.a) Eye(2)
8191
@test 2a == (2a.a) Eye(2)
8292
@test a * a == (a.a * a.a) Eye(2)
93+
94+
# Eye ⊗ A
95+
a = Eye(2) randn(3, 3)
96+
for f in KroneckerArrays.MATRIX_FUNCTIONS
97+
@eval begin
98+
fa = $f($a)
99+
@test collect(fa) $f(collect($a))
100+
@test fa.a isa Eye
101+
end
102+
end
103+
104+
fa = inv(a)
105+
@test collect(fa) inv(collect(a))
106+
@test fa.a isa Eye
107+
108+
fa = pinv(a)
109+
@test collect(fa) pinv(collect(a))
110+
@test fa.a isa Eye
111+
112+
@test det(a) det(collect(a))
113+
114+
# A ⊗ Eye
115+
a = randn(3, 3) Eye(2)
116+
for f in KroneckerArrays.MATRIX_FUNCTIONS
117+
@eval begin
118+
fa = $f($a)
119+
@test collect(fa) $f(collect($a))
120+
@test fa.b isa Eye
121+
end
122+
end
123+
124+
fa = inv(a)
125+
@test collect(fa) inv(collect(a))
126+
@test fa.b isa Eye
127+
128+
fa = pinv(a)
129+
@test collect(fa) pinv(collect(a))
130+
@test fa.b isa Eye
131+
132+
@test det(a) det(collect(a))
133+
134+
# Eye ⊗ Eye
135+
a = Eye(2) Eye(2)
136+
for f in KroneckerArrays.MATRIX_FUNCTIONS
137+
@eval begin
138+
@test_throws ArgumentError $f($a)
139+
end
140+
end
141+
142+
fa = inv(a)
143+
@test fa == a
144+
@test fa.a isa Eye
145+
@test fa.b isa Eye
146+
147+
fa = pinv(a)
148+
@test fa == a
149+
@test fa.a isa Eye
150+
@test fa.b isa Eye
151+
152+
@test det(a) det(collect(a)) 1
83153
end

0 commit comments

Comments
 (0)