Skip to content
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.27"
version = "0.1.28"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -16,22 +16,25 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"

[extensions]
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
KroneckerArraysTensorAlgebraExt = "TensorAlgebra"
KroneckerArraysTensorProductsExt = "TensorProducts"

[compat]
Adapt = "4.3"
BlockArrays = "1.6"
BlockSparseArrays = "0.9"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3.5"
DerivableInterfaces = "0.5.3"
DiagonalArrays = "0.3.11"
FillArrays = "1.13"
GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.9"
MatrixAlgebraKit = "0.2"
TensorAlgebra = "0.3.10"
TensorProducts = "0.1.7"
julia = "1.10"
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ using KroneckerArrays:
_similar

function KroneckerArrays.arg1(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
return mortar_axis(arg1.(eachblockaxis(r)))
end
function KroneckerArrays.arg2(r::AbstractBlockedUnitRange)
return mortar_axis(arg2.(eachblockaxis(r)))
Expand All @@ -56,15 +56,14 @@ function block_axes(ax::NTuple{N,AbstractUnitRange{<:Integer}}, I::Block{N}) whe
return block_axes(ax, Tuple(I)...)
end

## TODO: Is this needed?
function Base.getindex(
a::ZeroBlocks{2,KroneckerMatrix{T,A,B}}, I::Vararg{Int,2}
) where {T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}}
ax_a1 = arg1.(a.parentaxes)
ax_a1 = map(arg1, a.parentaxes)
a1 = ZeroBlocks{2,A}(ax_a1)[I...]

ax_a2 = arg2.(a.parentaxes)
ax_a2 = map(arg2, a.parentaxes)
a2 = ZeroBlocks{2,B}(ax_a2)[I...]

return a1 ⊗ a2
end
function Base.getindex(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module KroneckerArraysTensorAlgebraExt

using KroneckerArrays: KroneckerArrays, KroneckerArray, ⊗, arg1, arg2
using TensorAlgebra:
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize

struct KroneckerFusion{A<:FusionStyle,B<:FusionStyle} <: FusionStyle
a::A
b::B
end
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
function TensorAlgebra.FusionStyle(a::KroneckerArray)
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
end
function matricize_kronecker(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return matricize(arg1(style), arg1(a), biperm) ⊗ matricize(arg2(style), arg2(a), biperm)
end
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
return matricize_kronecker(style, a, biperm)
end
# Fix ambiguity error.
# TODO: Investigate rewriting the logic in `TensorAlgebra.jl` to avoid this.
using TensorAlgebra: BlockedTrivialPermutation, unmatricize
function TensorAlgebra.matricize(
style::KroneckerFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2}
)
return matricize_kronecker(style, a, biperm)
end
function unmatricize_kronecker(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize(arg1(style), arg1(a), arg1.(ax)) ⊗
unmatricize(arg2(style), arg2(a), arg2.(ax))
end
function TensorAlgebra.unmatricize(style::KroneckerFusion, a::AbstractArray, ax)
return unmatricize_kronecker(style, a, ax)
end

end
6 changes: 6 additions & 0 deletions src/cartesianproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ unproduct(r::CartesianProductUnitRange) = getfield(r, :range)
arg1(a::CartesianProductUnitRange) = arg1(cartesianproduct(a))
arg2(a::CartesianProductUnitRange) = arg2(cartesianproduct(a))

function Base.getindex(a::CartesianProductUnitRange, i::CartesianProductUnitRange)
prod = cartesianproduct(a)[cartesianproduct(i)]
range = unproduct(a)[unproduct(i)]
return cartesianrange(prod, range)
end

function Base.show(io::IO, a::CartesianProductUnitRange)
show(io, unproduct(a))
return nothing
Expand Down
93 changes: 89 additions & 4 deletions src/fillarrays/kroneckerarray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using FillArrays: FillArrays, Zeros
using FillArrays: FillArrays, Ones, Zeros
function FillArrays.fillsimilar(
a::Zeros{T},
ax::Tuple{
Expand All @@ -21,6 +21,11 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

using DiagonalArrays: Delta
const DeltaKronecker{T,N,A<:Delta{T,N},B<:AbstractArray{T,N}} = KroneckerArray{T,N,A,B}
const KroneckerDelta{T,N,A<:AbstractArray{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B}
const DeltaDelta{T,N,A<:Delta{T,N},B<:Delta{T,N}} = KroneckerArray{T,N,A,B}

_getindex(a::Eye, I1::Colon, I2::Colon) = a
_getindex(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_getindex(a::Eye, I1::Base.Slice, I2::Colon) = a
Expand All @@ -30,15 +35,23 @@ _view(a::Eye, I1::Base.Slice, I2::Base.Slice) = a
_view(a::Eye, I1::Base.Slice, I2::Colon) = a
_view(a::Eye, I1::Colon, I2::Base.Slice) = a

function _getindex(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...)
return a
end
function _view(a::Delta, I1::Union{Colon,Base.Slice}, Irest::Union{Colon,Base.Slice}...)
return a
end

# Like `adapt` but preserves `Eye`.
_adapt(to, a::Eye) = a
_adapt(to, a::Delta) = a

# Allows customizing for `FillArrays.Eye`.
function _convert(::Type{AbstractArray{T}}, a::RectDiagonal) where {T}
_convert(AbstractMatrix{T}, a)
return _convert(AbstractMatrix{T}, a)
end
function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
return RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
end

# Like `similar` but preserves `Eye`, `Ones`, etc.
Expand All @@ -61,8 +74,33 @@ function _similar(arrayt::Type{<:SquareEye}, axs::NTuple{2,AbstractUnitRange})
return Eye{eltype(arrayt)}((only(unique(axs)),))
end

# Like `copy` but preserves `Eye`.
function _similar(a::Delta, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
return Delta{elt}(axs)
end
function _similar(arrayt::Type{<:Delta}, axs::Tuple{Vararg{AbstractUnitRange}})
return Delta{eltype(arrayt)}(axs)
end

# Like `copy` but preserves `Eye`/`Delta`.
_copy(a::Eye) = a
_copy(a::Delta) = a

function _copyto!!(dest::Eye{<:Any,N}, src::Eye{<:Any,N}) where {N}
size(dest) == size(src) ||
throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src))."))
return dest
end
function _copyto!!(dest::Delta{<:Any,N}, src::Delta{<:Any,N}) where {N}
size(dest) == size(src) ||
throw(ArgumentError("Sizes do not match: $(size(dest)) != $(size(src))."))
return dest
end

function _permutedims!!(dest::Delta, src::Delta, perm)
Base.PermutedDimsArrays.genperm(axes(src), perm) == axes(dest) ||
throw(ArgumentError("Permuted axes do not match."))
return dest
end

using Base.Broadcast:
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
Expand All @@ -75,10 +113,16 @@ end
Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle()
Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2

function _copyto!!(dest::Eye, src::Broadcasted{<:EyeStyle,<:Any,typeof(identity)})
axes(dest) == axes(src) || error("Dimension mismatch.")
return dest
end

function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
return Eye{elt}(axes(bc))
end

# TODO: Define in terms of `_copyto!!` that is called on each argument.
function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}})
dest2 = arg2(dest)
f = LinearCombination(a)
Expand All @@ -99,6 +143,47 @@ function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),Eye
return error("Can't write in-place to `Eye ⊗ Eye`.")
end

struct DeltaStyle{N} <: AbstractArrayStyle{N} end
DeltaStyle(::Val{N}) where {N} = DeltaStyle{N}()
DeltaStyle{M}(::Val{N}) where {M,N} = DeltaStyle{N}()
function _BroadcastStyle(A::Type{<:Delta})
return DeltaStyle{ndims(A)}()
end
Base.BroadcastStyle(style1::DeltaStyle, style2::DeltaStyle) = DeltaStyle()
Base.BroadcastStyle(style1::DeltaStyle, style2::DefaultArrayStyle) = style2

function _copyto!!(dest::Delta, src::Broadcasted{<:DeltaStyle,<:Any,typeof(identity)})
axes(dest) == axes(src) || error("Dimension mismatch.")
return dest
end

function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type)
return Delta{elt}(axes(bc))
end

# TODO: Dispatch on `DeltaStyle`.
function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle})
dest2 = arg2(dest)
f = LinearCombination(a)
args = arguments(a)
arg2s = arg2.(args)
dest2 .= f.(arg2s...)
return dest
end
# TODO: Dispatch on `DeltaStyle`.
function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle})
dest1 = arg1(dest)
f = LinearCombination(a)
args = arguments(a)
arg1s = arg1.(args)
dest1 .= f.(arg1s...)
return dest
end
# TODO: Dispatch on `DeltaStyle`.
function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle})
return error("Can't write in-place to `Delta ⊗ Delta`.")
end

# Simplification rules similar to those for FillArrays.jl:
# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
using FillArrays: Zeros
Expand Down
52 changes: 46 additions & 6 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,26 @@ _copy(a::AbstractArray) = copy(a)
function Base.copy(a::KroneckerArray)
return _copy(arg1(a)) ⊗ _copy(arg2(a))
end
function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
copyto!(arg1(dest), arg1(src))
copyto!(arg2(dest), arg2(src))

# Allows extra customization, like for `FillArrays.Eye`.
function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where {N}
copyto!(dest, src)
return dest
end
function _copyto!!(dest::AbstractArray, src::Broadcasted)
copyto!(dest, src)
return dest
end

function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N}
return copyto!_kronecker(dest, src)
end
function copyto!_kronecker(
dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}
) where {N}
# TODO: Check if neither argument is mutated and if so error.
_copyto!!(arg1(dest), arg1(src))
_copyto!!(arg2(dest), arg2(src))
return dest
end

Expand Down Expand Up @@ -110,6 +127,23 @@ function Base.similar(
return similar(promote_type(A, B), sz)
end

function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm)
permutedims!(dest, src, perm)
return dest
end

using DerivableInterfaces: DerivableInterfaces, permuteddims
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
return permuteddims(arg1(a), perm) ⊗ permuteddims(arg2(a), perm)
end

function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
# TODO: Error if neither argument is mutable.
_permutedims!!(arg1(dest), arg1(src), perm)
_permutedims!!(arg2(dest), arg2(src), perm)
return dest
end

function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}})
return (t[1]..., flatten(Base.tail(t))...)
end
Expand All @@ -128,7 +162,7 @@ function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
a′ = reshape(a, interleave(size(a), ntuple(one, N)))
b′ = reshape(b, interleave(ntuple(one, N), size(b)))
c′ = permutedims(a′ .* b′, reverse(ntuple(identity, 2N)))
sz = ntuple(i -> size(a, i) * size(b, i), N)
sz = reverse(ntuple(i -> size(a, i) * size(b, i), N))
return permutedims(reshape(c′, sz), reverse(ntuple(identity, N)))
end
kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
Expand Down Expand Up @@ -284,6 +318,12 @@ for f in [:transpose, :adjoint, :inv]
end
end

function Base.reshape(
a::KroneckerArray, ax::Tuple{CartesianProductUnitRange,Vararg{CartesianProductUnitRange}}
)
return reshape(arg1(a), map(arg1, ax)) ⊗ reshape(arg2(a), map(arg2, ax))
end

# Allows for customizations for FillArrays.
_BroadcastStyle(x) = BroadcastStyle(x)

Expand Down Expand Up @@ -405,8 +445,8 @@ Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
Broadcast.broadcastable(a::KroneckerBroadcasted) = a
Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a))
function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted)
copyto!(arg1(dest), copy(arg1(a)))
copyto!(arg2(dest), copy(arg2(a)))
_copyto!!(arg1(dest), arg1(a))
_copyto!!(arg2(dest), arg2(a))
return dest
end
function Base.eltype(a::KroneckerBroadcasted)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Expand All @@ -34,6 +35,7 @@ MatrixAlgebraKit = "0.2"
SafeTestsets = "0.1"
StableRNGs = "1.0"
Suppressor = "0.2"
TensorAlgebra = "0.3.10"
TensorProducts = "0.1.7"
Test = "1.10"
TestExtras = "0.3"
Loading
Loading