Skip to content

Commit 255ab88

Browse files
authored
Add support for truncation, matrix functions (#9)
1 parent 6af509f commit 255ab88

File tree

5 files changed

+309
-23
lines changed

5 files changed

+309
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

src/KroneckerArrays.jl

Lines changed: 176 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ end
1111
arguments(a::CartesianProduct) = (a.a, a.b)
1212
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
1313

14+
function Base.show(io::IO, a::CartesianProduct)
15+
print(io, a.a, " × ", a.b)
16+
return nothing
17+
end
18+
1419
×(a, b) = CartesianProduct(a, b)
1520
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
1621
Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
@@ -130,6 +135,8 @@ function interleave(x::Tuple, y::Tuple)
130135
xy = ntuple(i -> (x[i], y[i]), length(x))
131136
return flatten(xy)
132137
end
138+
# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing:
139+
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
133140
function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
134141
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
135142
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
@@ -183,6 +190,9 @@ function Base.getindex(a::KroneckerArray, i::Integer)
183190
return a[CartesianIndices(a)[i]]
184191
end
185192

193+
# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing
194+
# in the n-dimensional case and use it to replace the matrix and vector cases:
195+
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
186196
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
187197
return error("Not implemented.")
188198
end
@@ -222,6 +232,10 @@ end
222232
function Base.inv(a::KroneckerArray)
223233
return inv(a.a) inv(a.b)
224234
end
235+
using LinearAlgebra: LinearAlgebra, pinv
236+
function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
237+
return pinv(a.a; kwargs...) pinv(a.b; kwargs...)
238+
end
225239
function Base.transpose(a::KroneckerArray)
226240
return transpose(a.a) transpose(a.b)
227241
end
@@ -297,6 +311,7 @@ using LinearAlgebra:
297311
Diagonal,
298312
Eigen,
299313
SVD,
314+
det,
300315
diag,
301316
eigen,
302317
eigvals,
@@ -335,9 +350,63 @@ end
335350
function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
336351
return norm(a.a, p) norm(a.b, p)
337352
end
353+
354+
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
355+
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
356+
return diagview(a.a) diagview(a.b)
357+
end
338358
function LinearAlgebra.diag(a::KroneckerArray)
339-
return diag(a.a) diag(a.b)
359+
return copy(diagview(a.a)) copy(diagview(a.b))
360+
end
361+
362+
# Matrix functions
363+
const MATRIX_FUNCTIONS = [
364+
:exp,
365+
:cis,
366+
:log,
367+
:sqrt,
368+
:cbrt,
369+
:cos,
370+
:sin,
371+
:tan,
372+
:csc,
373+
:sec,
374+
:cot,
375+
:cosh,
376+
:sinh,
377+
:tanh,
378+
:csch,
379+
:sech,
380+
:coth,
381+
:acos,
382+
:asin,
383+
:atan,
384+
:acsc,
385+
:asec,
386+
:acot,
387+
:acosh,
388+
:asinh,
389+
:atanh,
390+
:acsch,
391+
:asech,
392+
:acoth,
393+
]
394+
395+
for f in MATRIX_FUNCTIONS
396+
@eval begin
397+
function Base.$f(a::KroneckerArray)
398+
return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported."))
399+
end
400+
end
401+
end
402+
403+
using LinearAlgebra: checksquare
404+
function LinearAlgebra.det(a::KroneckerArray)
405+
checksquare(a.a)
406+
checksquare(a.b)
407+
return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1)
340408
end
409+
341410
function LinearAlgebra.svd(a::KroneckerArray)
342411
Fa = svd(a.a)
343412
Fb = svd(a.b)
@@ -690,18 +759,6 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
690759
end
691760
end
692761

693-
for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!]
694-
@eval begin
695-
function MatrixAlgebraKit.truncate!(
696-
::typeof($f),
697-
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
698-
strategy::TruncationStrategy,
699-
)
700-
return throw(MethodError(truncate!, ($f, (D, V), strategy)))
701-
end
702-
end
703-
end
704-
705762
for f in [:left_orth!, :right_orth!]
706763
@eval begin
707764
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
@@ -941,4 +998,110 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
941998
end
942999
end
9431000

1001+
using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate!
1002+
1003+
struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
1004+
strategy::T
1005+
end
1006+
1007+
# Avoid instantiating the identity.
1008+
function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2})
1009+
return a.a a.b[I[1].b, I[2].b]
1010+
end
1011+
function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2})
1012+
return a.a[I[1].a, I[2].a] a.b
1013+
end
1014+
function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2})
1015+
return a
1016+
end
1017+
1018+
using FillArrays: OnesVector
1019+
const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B}
1020+
const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
1021+
const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B}
1022+
1023+
function MatrixAlgebraKit.findtruncated(
1024+
values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy
1025+
)
1026+
I = findtruncated(Vector(values), strategy.strategy)
1027+
prods = collect(only(axes(values)).product)[I]
1028+
I_data = unique(map(x -> x.a, prods))
1029+
# Drop truncations that occur within the identity.
1030+
I_data = filter(I_data) do i
1031+
return count(x -> x.a == i, prods) == length(values.a)
1032+
end
1033+
return (:) × I_data
1034+
end
1035+
function MatrixAlgebraKit.findtruncated(
1036+
values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy
1037+
)
1038+
I = findtruncated(Vector(values), strategy.strategy)
1039+
prods = collect(only(axes(values)).product)[I]
1040+
I_data = unique(map(x -> x.b, prods))
1041+
# Drop truncations that occur within the identity.
1042+
I_data = filter(I_data) do i
1043+
return count(x -> x.b == i, prods) == length(values.b)
1044+
end
1045+
return I_data × (:)
1046+
end
1047+
function MatrixAlgebraKit.findtruncated(
1048+
values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy
1049+
)
1050+
return throw(ArgumentError("Can't truncate Eye ⊗ Eye."))
1051+
end
1052+
1053+
for f in [:eig_trunc!, :eigh_trunc!]
1054+
@eval begin
1055+
function MatrixAlgebraKit.truncate!(
1056+
::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy
1057+
)
1058+
return truncate!($f, DV, KroneckerTruncationStrategy(strategy))
1059+
end
1060+
function MatrixAlgebraKit.truncate!(
1061+
::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy
1062+
)
1063+
I = findtruncated(diagview(D), strategy)
1064+
return (D[I, I], V[(:) × (:), I])
1065+
end
1066+
end
1067+
end
1068+
1069+
function MatrixAlgebraKit.truncate!(
1070+
f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy
1071+
)
1072+
return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy))
1073+
end
1074+
function MatrixAlgebraKit.truncate!(
1075+
::typeof(svd_trunc!),
1076+
(U, S, Vᴴ)::NTuple{3,KroneckerMatrix},
1077+
strategy::KroneckerTruncationStrategy,
1078+
)
1079+
I = findtruncated(diagview(S), strategy)
1080+
return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)])
1081+
end
1082+
1083+
for f in MATRIX_FUNCTIONS
1084+
@eval begin
1085+
function Base.$f(a::SquareEyeKronecker)
1086+
return a.a $f(a.b)
1087+
end
1088+
function Base.$f(a::KroneckerSquareEye)
1089+
return $f(a.a) a.b
1090+
end
1091+
function Base.$f(a::SquareEyeSquareEye)
1092+
return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported."))
1093+
end
1094+
end
1095+
end
1096+
1097+
function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...)
1098+
return a.a pinv(a.b; kwargs...)
1099+
end
1100+
function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...)
1101+
return pinv(a.a; kwargs...) a.b
1102+
end
1103+
function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...)
1104+
return a
1105+
end
1106+
9441107
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
55
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
77
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
89
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1011
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
@@ -16,6 +17,7 @@ KroneckerArrays = "0.1"
1617
LinearAlgebra = "1.10"
1718
MatrixAlgebraKit = "0.2"
1819
SafeTestsets = "0.1"
20+
StableRNGs = "1.0"
1921
Suppressor = "0.2"
2022
Test = "1.10"
2123
TestExtras = "0.3"

test/test_basics.jl

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using FillArrays: Eye
22
using KroneckerArrays: KroneckerArrays, , ×, diagonal, kron_nd
3-
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
4-
using Test: @test, @testset
3+
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr
4+
using StableRNGs: StableRNG
5+
using Test: @test, @test_broken, @test_throws, @testset
56

67
const elts = (Float32, Float64, ComplexF32, ComplexF64)
78
@testset "KroneckerArrays (eltype=$elt)" for elt in elts
@@ -35,7 +36,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
3536
@test iszero(a - a)
3637
@test collect(a + c) collect(a) + collect(c)
3738
@test collect(b + c) collect(b) + collect(c)
38-
for f in (transpose, adjoint, inv)
39+
for f in (transpose, adjoint, inv, pinv)
3940
@test collect(f(a)) f(collect(a))
4041
end
4142
@test tr(a) tr(collect(a))
@@ -66,9 +67,25 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
6667
Q, R = qr(a)
6768
@test collect(Q * R) collect(a)
6869
@test collect(Q'Q) I
70+
71+
a = randn(elt, 2, 2) randn(elt, 3, 3)
72+
@test det(a) det(collect(a))
73+
74+
a = randn(elt, 2, 2) randn(elt, 3, 3)
75+
for f in KroneckerArrays.MATRIX_FUNCTIONS
76+
@eval begin
77+
@test_throws ArgumentError $f($a)
78+
end
79+
end
6980
end
7081

7182
@testset "FillArrays.Eye" begin
83+
MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS
84+
if VERSION < v"1.11-"
85+
# `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11.
86+
MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt])
87+
end
88+
7289
a = Eye(2) randn(3, 3)
7390
@test size(a) == (6, 6)
7491
@test a + a == Eye(2) (2a.b)
@@ -80,4 +97,66 @@ end
8097
@test a + a == (2a.a) Eye(2)
8198
@test 2a == (2a.a) Eye(2)
8299
@test a * a == (a.a * a.a) Eye(2)
100+
101+
# Eye ⊗ A
102+
rng = StableRNG(123)
103+
a = Eye(2) randn(rng, 3, 3)
104+
for f in MATRIX_FUNCTIONS
105+
@eval begin
106+
fa = $f($a)
107+
@test collect(fa) $f(collect($a)) rtol = (eps(real(eltype($a))))
108+
@test fa.a isa Eye
109+
end
110+
end
111+
112+
fa = inv(a)
113+
@test collect(fa) inv(collect(a))
114+
@test fa.a isa Eye
115+
116+
fa = pinv(a)
117+
@test collect(fa) pinv(collect(a))
118+
@test fa.a isa Eye
119+
120+
@test det(a) det(collect(a))
121+
122+
# A ⊗ Eye
123+
rng = StableRNG(123)
124+
a = randn(rng, 3, 3) Eye(2)
125+
for f in setdiff(MATRIX_FUNCTIONS, [:atanh])
126+
@eval begin
127+
fa = $f($a)
128+
@test collect(fa) $f(collect($a)) rtol = (eps(real(eltype($a))))
129+
@test fa.b isa Eye
130+
end
131+
end
132+
133+
fa = inv(a)
134+
@test collect(fa) inv(collect(a))
135+
@test fa.b isa Eye
136+
137+
fa = pinv(a)
138+
@test collect(fa) pinv(collect(a))
139+
@test fa.b isa Eye
140+
141+
@test det(a) det(collect(a))
142+
143+
# Eye ⊗ Eye
144+
a = Eye(2) Eye(2)
145+
for f in KroneckerArrays.MATRIX_FUNCTIONS
146+
@eval begin
147+
@test_throws ArgumentError $f($a)
148+
end
149+
end
150+
151+
fa = inv(a)
152+
@test fa == a
153+
@test fa.a isa Eye
154+
@test fa.b isa Eye
155+
156+
fa = pinv(a)
157+
@test fa == a
158+
@test fa.a isa Eye
159+
@test fa.b isa Eye
160+
161+
@test det(a) det(collect(a)) 1
83162
end

0 commit comments

Comments
 (0)