Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.2"
version = "0.1.3"

[deps]
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
Expand Down
59 changes: 44 additions & 15 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,29 @@ function Base.similar(
return similar(promote_type(A, B), sz)
end

Base.collect(a::KroneckerArray) = kron(a.a, a.b)
function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}})
return (t[1]..., flatten(Base.tail(t))...)
end
function flatten(t::Tuple{Tuple})
return t[1]
end
flatten(::Tuple{}) = ()
function interleave(x::Tuple, y::Tuple)
length(x) == length(y) || throw(ArgumentError("Tuples must have the same length."))
xy = ntuple(i -> (x[i], y[i]), length(x))
return flatten(xy)
end
function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N)))
sz = ntuple(i -> size(a, i) * size(b, i), N)
return permutedims(reshape(c′, sz), reverse(ntuple(identity, N)))
end
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)

Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b)

function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
return convert(Array{T,N}, collect(a))
Expand Down Expand Up @@ -150,10 +172,18 @@ function Base.show(io::IO, a::KroneckerArray)
return nothing
end

⊗(a::AbstractVecOrMat, b::AbstractVecOrMat) = KroneckerArray(a, b)
⊗(a::AbstractArray, b::AbstractArray) = KroneckerArray(a, b)
⊗(a::Number, b::Number) = a * b
⊗(a::Number, b::AbstractVecOrMat) = a * b
⊗(a::AbstractVecOrMat, b::Number) = a * b
⊗(a::Number, b::AbstractArray) = a * b
⊗(a::AbstractArray, b::Number) = a * b

function Base.getindex(a::KroneckerArray, i::Integer)
return a[CartesianIndices(a)[i]]
end

function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
return error("Not implemented.")
end

function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
GPUArraysCore.assertscalar("getindex")
Expand All @@ -162,22 +192,21 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
k, l = size(a.b)
return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
end
function Base.getindex(a::KroneckerMatrix, i::Integer)
return a[CartesianIndices(a)[i]]
end

function Base.getindex(a::KroneckerVector, i::Integer)
GPUArraysCore.assertscalar("getindex")
k = length(a.b)
return a.a[cld(i, k)] * a.b[(i - 1) % k + 1]
end

function Base.getindex(a::KroneckerVector, i::CartesianProduct)
return a.a[i.a] ⊗ a.b[i.b]
end
function Base.getindex(a::KroneckerMatrix, i::CartesianProduct, j::CartesianProduct)
return a.a[i.a, j.a] ⊗ a.b[i.b, j.b]
## function Base.getindex(a::KroneckerVector, i::CartesianProduct)
## return a.a[i.a] ⊗ a.b[i.b]
## end
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...]
end
# Fix ambigiuity error.
Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[]

function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
return a.a == b.a && a.b == b.b
Expand Down Expand Up @@ -220,7 +249,7 @@ using LinearAlgebra:
svd,
svdvals,
tr
diagonal(a::AbstractVecOrMat) = Diagonal(a)
diagonal(a::AbstractArray) = Diagonal(a)
function diagonal(a::KroneckerArray)
return Diagonal(a.a) ⊗ Diagonal(a.b)
end
Expand Down Expand Up @@ -275,10 +304,10 @@ end
function Base.:*(a::KroneckerQ, b::KroneckerQ)
return (a.a * b.a) ⊗ (a.b * b.b)
end
function Base.:*(a::KroneckerQ, b::KroneckerMatrix)
function Base.:*(a::KroneckerQ, b::KroneckerArray)
return (a.a * b.a) ⊗ (a.b * b.b)
end
function Base.:*(a::KroneckerMatrix, b::KroneckerQ)
function Base.:*(a::KroneckerArray, b::KroneckerQ)
return (a.a * b.a) ⊗ (a.b * b.b)
end
function Base.adjoint(a::KroneckerQ)
Expand Down
14 changes: 13 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal
using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal, kron_nd
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
using Test: @test, @testset

Expand Down Expand Up @@ -39,17 +39,29 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
end
@test tr(a) ≈ tr(collect(a))

a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3)
@test collect(a) ≈ kron_nd(a.a, a.b)
@test a[1 × 1, 1 × 1, 1 × 1] == a.a[1, 1, 1] * a.b[1, 1, 1]
@test a[1 × 3, 2 × 1, 2 × 2] == a.a[1, 2, 2] * a.b[3, 1, 2]
@test collect(a + a) ≈ 2 * collect(a)
@test collect(2a) ≈ 2 * collect(a)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c = a.a ⊗ b.b
U, S, V = svd(a)
@test collect(U * diagonal(S) * V') ≈ collect(a)
@test svdvals(a) ≈ S
@test sort(collect(S); rev=true) ≈ svdvals(collect(a))
@test collect(U'U) ≈ I
@test collect(V * V') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
D, V = eigen(a)
@test collect(a * V) ≈ collect(V * diagonal(D))
@test eigvals(a) ≈ D

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
Q, R = qr(a)
@test collect(Q * R) ≈ collect(a)
@test collect(Q'Q) ≈ I
Expand Down
Loading