diff --git a/Project.toml b/Project.toml index c13cd4e..517d91c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.28" +version = "0.1.29" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -33,7 +33,7 @@ DiagonalArrays = "0.3.11" FillArrays = "1.13" GPUArraysCore = "0.2" LinearAlgebra = "1.10" -MapBroadcast = "0.1.9" +MapBroadcast = "0.1.10" MatrixAlgebraKit = "0.2" TensorAlgebra = "0.3.10" TensorProducts = "0.1.7" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 28bdf3e..4552a2f 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -2,7 +2,6 @@ module KroneckerArrays export ⊗, × -include("linearcombination.jl") include("cartesianproduct.jl") include("kroneckerarray.jl") include("linearalgebra.jl") diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 0faeb48..90b3e14 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -123,23 +123,27 @@ function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) end # TODO: Define in terms of `_copyto!!` that is called on each argument. -function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}}) +function Base.copyto!(dest::EyeKronecker, a::Summed{<:KroneckerStyle{<:Any,EyeStyle()}}) dest2 = arg2(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg2s = arg2.(args) dest2 .= f.(arg2s...) return dest end -function Base.copyto!(dest::KroneckerEye, a::Sum{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}}) +function Base.copyto!( + dest::KroneckerEye, a::Summed{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}} +) dest1 = arg1(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg1s = arg1.(args) dest1 .= f.(arg1s...) return dest end -function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}}) +function Base.copyto!( + dest::EyeEye, a::Summed{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}} +) return error("Can't write in-place to `Eye ⊗ Eye`.") end @@ -162,25 +166,25 @@ function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type) end # TODO: Dispatch on `DeltaStyle`. -function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle}) +function Base.copyto!(dest::DeltaKronecker, a::Summed{<:KroneckerStyle}) dest2 = arg2(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg2s = arg2.(args) dest2 .= f.(arg2s...) return dest end # TODO: Dispatch on `DeltaStyle`. -function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle}) +function Base.copyto!(dest::KroneckerDelta, a::Summed{<:KroneckerStyle}) dest1 = arg1(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg1s = arg1.(args) dest1 .= f.(arg1s...) return dest end # TODO: Dispatch on `DeltaStyle`. -function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle}) +function Base.copyto!(dest::DeltaDelta, a::Summed{<:KroneckerStyle}) return error("Can't write in-place to `Delta ⊗ Delta`.") end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 9d30a08..bde7b33 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -49,6 +49,7 @@ function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) wh copyto!(dest, src) return dest end +using Base.Broadcast: Broadcasted function _copyto!!(dest::AbstractArray, src::Broadcasted) copyto!(dest, src) return dest @@ -368,11 +369,12 @@ function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::Kronecke return dest end -function Base.copyto!(dest::KroneckerArray, a::Sum{<:KroneckerStyle}) +using MapBroadcast: MapBroadcast, LinearCombination, Summed +function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle}) dest1 = arg1(dest) dest2 = arg2(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg1s = arg1.(args) arg2s = arg2.(args) if allequal(arg2s) @@ -393,26 +395,26 @@ end # Linear operations. function Broadcast.broadcasted(::KroneckerStyle, ::typeof(+), a, b) - return Sum(a) + Sum(b) + return Summed(a) + Summed(b) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a, b) - return Sum(a) - Sum(b) + return Summed(a) - Summed(b) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a) - return c * Sum(a) + return c * Summed(a) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) - return Sum(a) * c + return Summed(a) * c end # Fix ambiguity error. function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number) return a * b end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) - return Sum(a) / c + return Summed(a) / c end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a) - return -Sum(a) + return -Summed(a) end # Rewrite rules to canonicalize broadcast expressions.