diff --git a/Project.toml b/Project.toml index f0ba43d..b0cf377 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index c072dbe..c910c7d 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -118,7 +118,29 @@ function Base.similar( return similar(promote_type(A, B), sz) end -Base.collect(a::KroneckerArray) = kron(a.a, a.b) +function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}}) + return (t[1]..., flatten(Base.tail(t))...) +end +function flatten(t::Tuple{Tuple}) + return t[1] +end +flatten(::Tuple{}) = () +function interleave(x::Tuple, y::Tuple) + length(x) == length(y) || throw(ArgumentError("Tuples must have the same length.")) + xy = ntuple(i -> (x[i], y[i]), length(x)) + return flatten(xy) +end +function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} + a′ = reshape(a, interleave(size(a), ntuple(one, N))) + b′ = reshape(b, interleave(ntuple(one, N), size(b))) + c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N))) + sz = ntuple(i -> size(a, i) * size(b, i), N) + return permutedims(reshape(c′, sz), reverse(ntuple(identity, N))) +end +kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b) +kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b) + +Base.collect(a::KroneckerArray) = kron_nd(a.a, a.b) function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} return convert(Array{T,N}, collect(a)) @@ -150,10 +172,18 @@ function Base.show(io::IO, a::KroneckerArray) return nothing end -⊗(a::AbstractVecOrMat, b::AbstractVecOrMat) = KroneckerArray(a, b) +⊗(a::AbstractArray, b::AbstractArray) = KroneckerArray(a, b) ⊗(a::Number, b::Number) = a * b -⊗(a::Number, b::AbstractVecOrMat) = a * b -⊗(a::AbstractVecOrMat, b::Number) = a * b +⊗(a::Number, b::AbstractArray) = a * b +⊗(a::AbstractArray, b::Number) = a * b + +function Base.getindex(a::KroneckerArray, i::Integer) + return a[CartesianIndices(a)[i]] +end + +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} + return error("Not implemented.") +end function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) GPUArraysCore.assertscalar("getindex") @@ -162,9 +192,6 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) k, l = size(a.b) return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] end -function Base.getindex(a::KroneckerMatrix, i::Integer) - return a[CartesianIndices(a)[i]] -end function Base.getindex(a::KroneckerVector, i::Integer) GPUArraysCore.assertscalar("getindex") @@ -172,12 +199,14 @@ function Base.getindex(a::KroneckerVector, i::Integer) return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] end -function Base.getindex(a::KroneckerVector, i::CartesianProduct) - return a.a[i.a] ⊗ a.b[i.b] -end -function Base.getindex(a::KroneckerMatrix, i::CartesianProduct, j::CartesianProduct) - return a.a[i.a, j.a] ⊗ a.b[i.b, j.b] +## function Base.getindex(a::KroneckerVector, i::CartesianProduct) +## return a.a[i.a] ⊗ a.b[i.b] +## end +function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} + return a.a[map(Base.Fix2(getfield, :a), I)...] ⊗ a.b[map(Base.Fix2(getfield, :b), I)...] end +# Fix ambigiuity error. +Base.getindex(a::KroneckerArray{<:Any,0}) = a.a[] * a.b[] function Base.:(==)(a::KroneckerArray, b::KroneckerArray) return a.a == b.a && a.b == b.b @@ -220,7 +249,7 @@ using LinearAlgebra: svd, svdvals, tr -diagonal(a::AbstractVecOrMat) = Diagonal(a) +diagonal(a::AbstractArray) = Diagonal(a) function diagonal(a::KroneckerArray) return Diagonal(a.a) ⊗ Diagonal(a.b) end @@ -275,10 +304,10 @@ end function Base.:*(a::KroneckerQ, b::KroneckerQ) return (a.a * b.a) ⊗ (a.b * b.b) end -function Base.:*(a::KroneckerQ, b::KroneckerMatrix) +function Base.:*(a::KroneckerQ, b::KroneckerArray) return (a.a * b.a) ⊗ (a.b * b.b) end -function Base.:*(a::KroneckerMatrix, b::KroneckerQ) +function Base.:*(a::KroneckerArray, b::KroneckerQ) return (a.a * b.a) ⊗ (a.b * b.b) end function Base.adjoint(a::KroneckerQ) diff --git a/test/test_basics.jl b/test/test_basics.jl index 17892cc..1c10b53 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,4 +1,4 @@ -using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal +using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal, kron_nd using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr using Test: @test, @testset @@ -39,6 +39,16 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) end @test tr(a) ≈ tr(collect(a)) + a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) + @test collect(a) ≈ kron_nd(a.a, a.b) + @test a[1 × 1, 1 × 1, 1 × 1] == a.a[1, 1, 1] * a.b[1, 1, 1] + @test a[1 × 3, 2 × 1, 2 × 2] == a.a[1, 2, 2] * a.b[3, 1, 2] + @test collect(a + a) ≈ 2 * collect(a) + @test collect(2a) ≈ 2 * collect(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c = a.a ⊗ b.b U, S, V = svd(a) @test collect(U * diagonal(S) * V') ≈ collect(a) @test svdvals(a) ≈ S @@ -46,10 +56,12 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) @test collect(U'U) ≈ I @test collect(V * V') ≈ I + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) D, V = eigen(a) @test collect(a * V) ≈ collect(V * diagonal(D)) @test eigvals(a) ≈ D + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) Q, R = qr(a) @test collect(Q * R) ≈ collect(a) @test collect(Q'Q) ≈ I