Skip to content

Commit 2ffe0a0

Browse files
authored
Use Summed from MapBroadcast.jl (#37)
1 parent 0f2daf7 commit 2ffe0a0

File tree

4 files changed

+26
-21
lines changed

4 files changed

+26
-21
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.28"
4+
version = "0.1.29"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -33,7 +33,7 @@ DiagonalArrays = "0.3.11"
3333
FillArrays = "1.13"
3434
GPUArraysCore = "0.2"
3535
LinearAlgebra = "1.10"
36-
MapBroadcast = "0.1.9"
36+
MapBroadcast = "0.1.10"
3737
MatrixAlgebraKit = "0.2"
3838
TensorAlgebra = "0.3.10"
3939
TensorProducts = "0.1.7"

src/KroneckerArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module KroneckerArrays
22

33
export , ×
44

5-
include("linearcombination.jl")
65
include("cartesianproduct.jl")
76
include("kroneckerarray.jl")
87
include("linearalgebra.jl")

src/fillarrays/kroneckerarray.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,23 +123,27 @@ function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
123123
end
124124

125125
# TODO: Define in terms of `_copyto!!` that is called on each argument.
126-
function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}})
126+
function Base.copyto!(dest::EyeKronecker, a::Summed{<:KroneckerStyle{<:Any,EyeStyle()}})
127127
dest2 = arg2(dest)
128128
f = LinearCombination(a)
129-
args = arguments(a)
129+
args = MapBroadcast.arguments(a)
130130
arg2s = arg2.(args)
131131
dest2 .= f.(arg2s...)
132132
return dest
133133
end
134-
function Base.copyto!(dest::KroneckerEye, a::Sum{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}})
134+
function Base.copyto!(
135+
dest::KroneckerEye, a::Summed{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}}
136+
)
135137
dest1 = arg1(dest)
136138
f = LinearCombination(a)
137-
args = arguments(a)
139+
args = MapBroadcast.arguments(a)
138140
arg1s = arg1.(args)
139141
dest1 .= f.(arg1s...)
140142
return dest
141143
end
142-
function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}})
144+
function Base.copyto!(
145+
dest::EyeEye, a::Summed{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}}
146+
)
143147
return error("Can't write in-place to `Eye ⊗ Eye`.")
144148
end
145149

@@ -162,25 +166,25 @@ function Base.similar(bc::Broadcasted{<:DeltaStyle}, elt::Type)
162166
end
163167

164168
# TODO: Dispatch on `DeltaStyle`.
165-
function Base.copyto!(dest::DeltaKronecker, a::Sum{<:KroneckerStyle})
169+
function Base.copyto!(dest::DeltaKronecker, a::Summed{<:KroneckerStyle})
166170
dest2 = arg2(dest)
167171
f = LinearCombination(a)
168-
args = arguments(a)
172+
args = MapBroadcast.arguments(a)
169173
arg2s = arg2.(args)
170174
dest2 .= f.(arg2s...)
171175
return dest
172176
end
173177
# TODO: Dispatch on `DeltaStyle`.
174-
function Base.copyto!(dest::KroneckerDelta, a::Sum{<:KroneckerStyle})
178+
function Base.copyto!(dest::KroneckerDelta, a::Summed{<:KroneckerStyle})
175179
dest1 = arg1(dest)
176180
f = LinearCombination(a)
177-
args = arguments(a)
181+
args = MapBroadcast.arguments(a)
178182
arg1s = arg1.(args)
179183
dest1 .= f.(arg1s...)
180184
return dest
181185
end
182186
# TODO: Dispatch on `DeltaStyle`.
183-
function Base.copyto!(dest::DeltaDelta, a::Sum{<:KroneckerStyle})
187+
function Base.copyto!(dest::DeltaDelta, a::Summed{<:KroneckerStyle})
184188
return error("Can't write in-place to `Delta ⊗ Delta`.")
185189
end
186190

src/kroneckerarray.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) wh
4949
copyto!(dest, src)
5050
return dest
5151
end
52+
using Base.Broadcast: Broadcasted
5253
function _copyto!!(dest::AbstractArray, src::Broadcasted)
5354
copyto!(dest, src)
5455
return dest
@@ -368,11 +369,12 @@ function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::Kronecke
368369
return dest
369370
end
370371

371-
function Base.copyto!(dest::KroneckerArray, a::Sum{<:KroneckerStyle})
372+
using MapBroadcast: MapBroadcast, LinearCombination, Summed
373+
function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle})
372374
dest1 = arg1(dest)
373375
dest2 = arg2(dest)
374376
f = LinearCombination(a)
375-
args = arguments(a)
377+
args = MapBroadcast.arguments(a)
376378
arg1s = arg1.(args)
377379
arg2s = arg2.(args)
378380
if allequal(arg2s)
@@ -393,26 +395,26 @@ end
393395

394396
# Linear operations.
395397
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(+), a, b)
396-
return Sum(a) + Sum(b)
398+
return Summed(a) + Summed(b)
397399
end
398400
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a, b)
399-
return Sum(a) - Sum(b)
401+
return Summed(a) - Summed(b)
400402
end
401403
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a)
402-
return c * Sum(a)
404+
return c * Summed(a)
403405
end
404406
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number)
405-
return Sum(a) * c
407+
return Summed(a) * c
406408
end
407409
# Fix ambiguity error.
408410
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number)
409411
return a * b
410412
end
411413
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number)
412-
return Sum(a) / c
414+
return Summed(a) / c
413415
end
414416
function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a)
415-
return -Sum(a)
417+
return -Summed(a)
416418
end
417419

418420
# Rewrite rules to canonicalize broadcast expressions.

0 commit comments

Comments
 (0)