Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"

[compat]
DerivableInterfaces = "0.4.5"
GPUArraysCore = "0.2.0"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2.0"
julia = "1.10"
239 changes: 233 additions & 6 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module KroneckerArrays

using GPUArraysCore: GPUArraysCore

export ⊗, ×

struct CartesianProduct{A,B}
Expand Down Expand Up @@ -28,6 +30,26 @@ end
Base.first(r::CartesianProductUnitRange) = first(r.range)
Base.last(r::CartesianProductUnitRange) = last(r.range)

function Base.axes(r::CartesianProductUnitRange)
return (CartesianProductUnitRange(r.product, only(axes(r.range))),)
end

using Base.Broadcast: DefaultArrayStyle
for f in (:+, :-)
@eval begin
function Broadcast.broadcasted(
::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer
)
return CartesianProductUnitRange(r.product, $f.(r.range, x))
end
function Broadcast.broadcasted(
::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange
)
return CartesianProductUnitRange(r.product, $f.(x, r.range))
end
end
end

struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
a::A
b::B
Expand All @@ -44,6 +66,15 @@ end
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}

function Base.copy(a::KroneckerArray)
return copy(a.a) ⊗ copy(a.b)
end
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
copyto!(dest.a, src.a)
copyto!(dest.b, src.b)
return dest
end

function Base.similar(
a::AbstractArray,
elt::Type,
Expand Down Expand Up @@ -73,9 +104,21 @@ function Base.similar(
return similar(arrayt, map(ax -> ax.product.a, axs)) ⊗
similar(arrayt, map(ax -> ax.product.b, axs))
end
function Base.similar(
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}},
axs::Tuple{
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
},
) where {A,B}
return similar(A, map(ax -> ax.product.a, axs)) ⊗ similar(B, map(ax -> ax.product.b, axs))
end

Base.collect(a::KroneckerArray) = kron(a.a, a.b)

function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
return convert(Array{T,N}, collect(a))
end

Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a))

function Base.axes(a::KroneckerArray)
Expand Down Expand Up @@ -107,12 +150,23 @@ end
⊗(a::Number, b::AbstractVecOrMat) = a * b
⊗(a::AbstractVecOrMat, b::Number) = a * b

function Base.getindex(::KroneckerArray, ::Int)
return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported."))
function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
GPUArraysCore.assertscalar("getindex")
# Code logic from Kronecker.jl:
# https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
k, l = size(a.b)
return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
end
function Base.getindex(a::KroneckerMatrix, i::Integer)
return a[CartesianIndices(a)[i]]
end
function Base.getindex(::KroneckerArray{<:Any,N}, ::Vararg{Int,N}) where {N}
return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported."))

function Base.getindex(a::KroneckerVector, i::Integer)
GPUArraysCore.assertscalar("getindex")
k = length(a.b)
return a.a[cld(i, k)] * a.b[(i - 1) % k + 1]
end

function Base.getindex(a::KroneckerVector, i::CartesianProduct)
return a.a[i.a] ⊗ a.b[i.b]
end
Expand Down Expand Up @@ -169,9 +223,18 @@ end
function Base.:*(a::KroneckerArray, b::KroneckerArray)
return (a.a * b.a) ⊗ (a.b * b.b)
end
function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
function LinearAlgebra.mul!(
c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number
)
iszero(β) ||
iszero(c) ||
throw(
ArgumentError(
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
),
)
mul!(c.a, a.a, b.a)
mul!(c.b, a.b, b.b)
mul!(c.b, a.b, b.b, α, β)
return c
end
function LinearAlgebra.tr(a::KroneckerArray)
Expand Down Expand Up @@ -269,4 +332,168 @@ for op in (:+, :-)
end
end

function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
dest.a .= a.a
dest.b .= a.b
return dest
end
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
if a.b == b.b
map!(+, dest.a, a.a, b.a)
dest.b .= a.b
elseif a.a == b.a
dest.a .= a.a
map!(+, dest.b, a.b, b.b)
else
throw(
ArgumentError(
"KroneckerArray addition is only supported when the first or second arguments match.",
),
)
end
return dest
end
function Base.map!(
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
dest.a .= f.x .* a.a
dest.b .= a.b
return dest
end
function Base.map!(
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
)
dest.a .= a.a
dest.b .= a.b .* f.x
return dest
end

using DerivableInterfaces: DerivableInterfaces, zero!
function DerivableInterfaces.zero!(a::KroneckerArray)
zero!(a.a)
zero!(a.b)
return a
end

using MatrixAlgebraKit:
MatrixAlgebraKit,
AbstractAlgorithm,
TruncationStrategy,
default_eig_algorithm,
default_eigh_algorithm,
default_lq_algorithm,
default_polar_algorithm,
default_qr_algorithm,
default_svd_algorithm,
eig_full!,
eig_trunc!,
eig_vals!,
eigh_full!,
eigh_trunc!,
eigh_vals!,
initialize_output,
left_null!,
left_orth!,
left_polar!,
lq_compact!,
lq_full!,
qr_compact!,
qr_full!,
right_null!,
right_orth!,
right_polar!,
svd_compact!,
svd_full!,
svd_trunc!,
svd_vals!,
truncate!

struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
a::A
b::B
end

for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
ff = Symbol("default_", f, "_algorithm")
@eval begin
function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...)
return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...))
end
end
end

for f in (
:eig_full!,
:eigh_full!,
:qr_compact!,
:qr_full!,
:left_polar!,
:lq_compact!,
:lq_full!,
:right_polar!,
:svd_compact!,
:svd_full!,
)
@eval begin
function MatrixAlgebraKit.initialize_output(
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
)
return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b)
end
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...)
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...)
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...)
return F
end
end
end

for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
@eval begin
function MatrixAlgebraKit.initialize_output(
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
)
return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b)
end
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm)
$f(a.a, F.a, alg.a)
$f(a.b, F.b, alg.b)
return F
end
end
end

for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
@eval begin
function MatrixAlgebraKit.truncate!(
::typeof($f),
(D, V)::Tuple{KroneckerMatrix,KroneckerMatrix},
strategy::TruncationStrategy,
)
return throw(MethodError(truncate!, ($f, (D, V), strategy)))
end
end
end

for f in (:left_orth!, :right_orth!)
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
return initialize_output($f, a.a) .⊗ initialize_output($f, a.b)
end
end
end

for f in (:left_null!, :right_null!)
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
return initialize_output($f, a.a) ⊗ initialize_output($f, a.b)
end
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...)
$f(a.a, F.a; kwargs...)
$f(a.b, F.b; kwargs...)
return F
end
end
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
Loading
Loading