From f617b5c2f218a93bb358022793d227746b7be864 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 14 Aug 2025 14:12:44 -0400 Subject: [PATCH 1/4] Use Summed from MapBroadcast --- Project.toml | 2 +- src/KroneckerArrays.jl | 1 - src/fillarrays/kroneckerarray.jl | 24 ++++++++++++++---------- src/kroneckerarray.jl | 18 ++++++++++-------- test/Project.toml | 1 + 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index c13cd4e..d5b8eb5 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ DiagonalArrays = "0.3.11" FillArrays = "1.13" GPUArraysCore = "0.2" LinearAlgebra = "1.10" -MapBroadcast = "0.1.9" +MapBroadcast = "0.1.11" MatrixAlgebraKit = "0.2" TensorAlgebra = "0.3.10" TensorProducts = "0.1.7" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 28bdf3e..4552a2f 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -2,7 +2,6 @@ module KroneckerArrays export ⊗, × -include("linearcombination.jl") include("cartesianproduct.jl") include("kroneckerarray.jl") include("linearalgebra.jl") diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 0faeb48..90b3e14 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -123,23 +123,27 @@ function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) end # TODO: Define in terms of `_copyto!!` that is called on each argument. -function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}}) +function Base.copyto!(dest::EyeKronecker, a::Summed{<:KroneckerStyle{<:Any,EyeStyle()}}) dest2 = arg2(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg2s = arg2.(args) dest2 .= f.(arg2s...) return dest end -function Base.copyto!(dest::KroneckerEye, a::Sum{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}}) +function Base.copyto!( + dest::KroneckerEye, a::Summed{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}} +) dest1 = arg1(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg1s = arg1.(args) dest1 .= f.(arg1s...) return dest end -function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}}) +function Base.copyto!( + dest::EyeEye, a::Summed{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}} +) return error("Can't write in-place to `Eye ⊗ Eye`.") end @@ -162,25 +166,25 @@ function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type) end # TODO: Dispatch on `DeltaStyle`. -function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle}) +function Base.copyto!(dest::DeltaKronecker, a::Summed{<:KroneckerStyle}) dest2 = arg2(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg2s = arg2.(args) dest2 .= f.(arg2s...) return dest end # TODO: Dispatch on `DeltaStyle`. -function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle}) +function Base.copyto!(dest::KroneckerDelta, a::Summed{<:KroneckerStyle}) dest1 = arg1(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg1s = arg1.(args) dest1 .= f.(arg1s...) return dest end # TODO: Dispatch on `DeltaStyle`. -function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle}) +function Base.copyto!(dest::DeltaDelta, a::Summed{<:KroneckerStyle}) return error("Can't write in-place to `Delta ⊗ Delta`.") end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 9d30a08..bde7b33 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -49,6 +49,7 @@ function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) wh copyto!(dest, src) return dest end +using Base.Broadcast: Broadcasted function _copyto!!(dest::AbstractArray, src::Broadcasted) copyto!(dest, src) return dest @@ -368,11 +369,12 @@ function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::Kronecke return dest end -function Base.copyto!(dest::KroneckerArray, a::Sum{<:KroneckerStyle}) +using MapBroadcast: MapBroadcast, LinearCombination, Summed +function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle}) dest1 = arg1(dest) dest2 = arg2(dest) f = LinearCombination(a) - args = arguments(a) + args = MapBroadcast.arguments(a) arg1s = arg1.(args) arg2s = arg2.(args) if allequal(arg2s) @@ -393,26 +395,26 @@ end # Linear operations. function Broadcast.broadcasted(::KroneckerStyle, ::typeof(+), a, b) - return Sum(a) + Sum(b) + return Summed(a) + Summed(b) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a, b) - return Sum(a) - Sum(b) + return Summed(a) - Summed(b) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a) - return c * Sum(a) + return c * Summed(a) end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) - return Sum(a) * c + return Summed(a) * c end # Fix ambiguity error. function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number) return a * b end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) - return Sum(a) / c + return Summed(a) / c end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a) - return -Sum(a) + return -Summed(a) end # Rewrite rules to canonicalize broadcast expressions. diff --git a/test/Project.toml b/test/Project.toml index f37978a..49918f3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" From 1e229ba7b699d6b30e1643986d7067afde28e2d6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 14 Aug 2025 14:48:57 -0400 Subject: [PATCH 2/4] Fix compat version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d5b8eb5..b989c52 100644 --- a/Project.toml +++ b/Project.toml @@ -33,7 +33,7 @@ DiagonalArrays = "0.3.11" FillArrays = "1.13" GPUArraysCore = "0.2" LinearAlgebra = "1.10" -MapBroadcast = "0.1.11" +MapBroadcast = "0.1.10" MatrixAlgebraKit = "0.2" TensorAlgebra = "0.3.10" TensorProducts = "0.1.7" From e915ffbf43b66655a4c2af4ab0e3e30d07ca71cb Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 14 Aug 2025 14:54:51 -0400 Subject: [PATCH 3/4] Remove stale test dep --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 49918f3..f37978a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,6 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" From 88e9fbd2717c2f90f27b6f7f7f6d7a7f1b909d27 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 14 Aug 2025 19:48:00 -0400 Subject: [PATCH 4/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b989c52..517d91c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.28" +version = "0.1.29" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"