Skip to content

Commit d3de597

Browse files
authored
Generalize beyond matrices (#5)
1 parent 1e854f3 commit d3de597

File tree

3 files changed

+58
-17
lines changed

3 files changed

+58
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"

src/KroneckerArrays.jl

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,29 @@ function Base.similar(
118118
return similar(promote_type(A, B), sz)
119119
end
120120

121-
Base.collect(a::KroneckerArray) = kron(a.a, a.b)
121+
function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}})
122+
return (t[1]..., flatten(Base.tail(t))...)
123+
end
124+
function flatten(t::Tuple{Tuple})
125+
return t[1]
126+
end
127+
flatten(::Tuple{}) = ()
128+
function interleave(x::Tuple, y::Tuple)
129+
length(x) == length(y) || throw(ArgumentError("Tuples must have the same length."))
130+
xy = ntuple(i -> (x[i], y[i]), length(x))
131+
return flatten(xy)
132+
end
133+
function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
134+
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
135+
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
136+
c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N)))
137+
sz = ntuple(i -> size(a, i) * size(b, i), N)
138+
return permutedims(reshape(c′, sz), reverse(ntuple(identity, N)))
139+
end
140+
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
141+
kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)
142+
143+
Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b)
122144

123145
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
124146
return convert(Array{T,N}, collect(a))
@@ -150,10 +172,18 @@ function Base.show(io::IO, a::KroneckerArray)
150172
return nothing
151173
end
152174

153-
(a::AbstractVecOrMat, b::AbstractVecOrMat) = KroneckerArray(a, b)
175+
(a::AbstractArray, b::AbstractArray) = KroneckerArray(a, b)
154176
(a::Number, b::Number) = a * b
155-
(a::Number, b::AbstractVecOrMat) = a * b
156-
(a::AbstractVecOrMat, b::Number) = a * b
177+
(a::Number, b::AbstractArray) = a * b
178+
(a::AbstractArray, b::Number) = a * b
179+
180+
function Base.getindex(a::KroneckerArray, i::Integer)
181+
return a[CartesianIndices(a)[i]]
182+
end
183+
184+
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
185+
return error("Not implemented.")
186+
end
157187

158188
function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
159189
GPUArraysCore.assertscalar("getindex")
@@ -162,22 +192,21 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
162192
k, l = size(a.b)
163193
return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
164194
end
165-
function Base.getindex(a::KroneckerMatrix, i::Integer)
166-
return a[CartesianIndices(a)[i]]
167-
end
168195

169196
function Base.getindex(a::KroneckerVector, i::Integer)
170197
GPUArraysCore.assertscalar("getindex")
171198
k = length(a.b)
172199
return a.a[cld(i, k)] * a.b[(i - 1) % k + 1]
173200
end
174201

175-
function Base.getindex(a::KroneckerVector, i::CartesianProduct)
176-
return a.a[i.a] a.b[i.b]
177-
end
178-
function Base.getindex(a::KroneckerMatrix, i::CartesianProduct, j::CartesianProduct)
179-
return a.a[i.a, j.a] a.b[i.b, j.b]
202+
## function Base.getindex(a::KroneckerVector, i::CartesianProduct)
203+
## return a.a[i.a] ⊗ a.b[i.b]
204+
## end
205+
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
206+
return a.a[map(Base.Fix2(getfield, :a), I)...] a.b[map(Base.Fix2(getfield, :b), I)...]
180207
end
208+
# Fix ambigiuity error.
209+
Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[]
181210

182211
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
183212
return a.a == b.a && a.b == b.b
@@ -220,7 +249,7 @@ using LinearAlgebra:
220249
svd,
221250
svdvals,
222251
tr
223-
diagonal(a::AbstractVecOrMat) = Diagonal(a)
252+
diagonal(a::AbstractArray) = Diagonal(a)
224253
function diagonal(a::KroneckerArray)
225254
return Diagonal(a.a) Diagonal(a.b)
226255
end
@@ -275,10 +304,10 @@ end
275304
function Base.:*(a::KroneckerQ, b::KroneckerQ)
276305
return (a.a * b.a) (a.b * b.b)
277306
end
278-
function Base.:*(a::KroneckerQ, b::KroneckerMatrix)
307+
function Base.:*(a::KroneckerQ, b::KroneckerArray)
279308
return (a.a * b.a) (a.b * b.b)
280309
end
281-
function Base.:*(a::KroneckerMatrix, b::KroneckerQ)
310+
function Base.:*(a::KroneckerArray, b::KroneckerQ)
282311
return (a.a * b.a) (a.b * b.b)
283312
end
284313
function Base.adjoint(a::KroneckerQ)

test/test_basics.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using KroneckerArrays: KroneckerArrays, , ×, diagonal
1+
using KroneckerArrays: KroneckerArrays, , ×, diagonal, kron_nd
22
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
33
using Test: @test, @testset
44

@@ -39,17 +39,29 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
3939
end
4040
@test tr(a) tr(collect(a))
4141

42+
a = randn(elt, 2, 2, 2) randn(elt, 3, 3, 3)
43+
@test collect(a) kron_nd(a.a, a.b)
44+
@test a[1 × 1, 1 × 1, 1 × 1] == a.a[1, 1, 1] * a.b[1, 1, 1]
45+
@test a[1 × 3, 2 × 1, 2 × 2] == a.a[1, 2, 2] * a.b[3, 1, 2]
46+
@test collect(a + a) 2 * collect(a)
47+
@test collect(2a) 2 * collect(a)
48+
49+
a = randn(elt, 2, 2) randn(elt, 3, 3)
50+
b = randn(elt, 2, 2) randn(elt, 3, 3)
51+
c = a.a b.b
4252
U, S, V = svd(a)
4353
@test collect(U * diagonal(S) * V') collect(a)
4454
@test svdvals(a) S
4555
@test sort(collect(S); rev=true) svdvals(collect(a))
4656
@test collect(U'U) I
4757
@test collect(V * V') I
4858

59+
a = randn(elt, 2, 2) randn(elt, 3, 3)
4960
D, V = eigen(a)
5061
@test collect(a * V) collect(V * diagonal(D))
5162
@test eigvals(a) D
5263

64+
a = randn(elt, 2, 2) randn(elt, 3, 3)
5365
Q, R = qr(a)
5466
@test collect(Q * R) collect(a)
5567
@test collect(Q'Q) I

0 commit comments

Comments
 (0)