Skip to content

Commit 1cd2ce3

Browse files
authored
Rewrite broadcasting and mapping to make it simpler and more general (#23)
1 parent 0b5388f commit 1cd2ce3

File tree

7 files changed

+307
-355
lines changed

7 files changed

+307
-355
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.17"
4+
version = "0.1.18"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/KroneckerArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module KroneckerArrays
22

33
export , ×
44

5+
include("linearcombination.jl")
56
include("cartesianproduct.jl")
67
include("kroneckerarray.jl")
78
include("linearalgebra.jl")

src/fillarrays/kroneckerarray.jl

Lines changed: 79 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -81,144 +81,6 @@ function DerivableInterfaces.zero!(a::EyeEye)
8181
return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`."))
8282
end
8383

84-
function Base.:*(a::Number, b::EyeKronecker)
85-
return b.a (a * b.b)
86-
end
87-
function Base.:*(a::Number, b::KroneckerEye)
88-
return (a * b.a) b.b
89-
end
90-
function Base.:*(a::Number, b::EyeEye)
91-
return error("Can't multiply `Eye ⊗ Eye` by a number.")
92-
end
93-
function Base.:*(a::EyeKronecker, b::Number)
94-
return a.a (a.b * b)
95-
end
96-
function Base.:*(a::KroneckerEye, b::Number)
97-
return (a.a * b) a.b
98-
end
99-
function Base.:*(a::EyeEye, b::Number)
100-
return error("Can't multiply `Eye ⊗ Eye` by a number.")
101-
end
102-
103-
function Base.:-(a::EyeKronecker)
104-
return a.a (-a.b)
105-
end
106-
function Base.:-(a::KroneckerEye)
107-
return (-a.a) a.b
108-
end
109-
function Base.:-(a::EyeEye)
110-
return error("Can't multiply `Eye ⊗ Eye` by a number.")
111-
end
112-
113-
for op in (:+, :-)
114-
@eval begin
115-
function Base.$op(a::EyeKronecker, b::EyeKronecker)
116-
if a.a b.a
117-
return throw(
118-
ArgumentError(
119-
"KroneckerArray addition is only supported when the first or secord arguments match.",
120-
),
121-
)
122-
end
123-
return a.a $op(a.b, b.b)
124-
end
125-
function Base.$op(a::KroneckerEye, b::KroneckerEye)
126-
if a.b b.b
127-
return throw(
128-
ArgumentError(
129-
"KroneckerArray addition is only supported when the first or secord arguments match.",
130-
),
131-
)
132-
end
133-
return $op(a.a, b.a) a.b
134-
end
135-
function Base.$op(a::EyeEye, b::EyeEye)
136-
if a.b b.b
137-
return throw(
138-
ArgumentError(
139-
"KroneckerArray addition is only supported when the first or secord arguments match.",
140-
),
141-
)
142-
end
143-
return $op(a.a, b.a) a.b
144-
end
145-
end
146-
end
147-
148-
function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker)
149-
map!(f, dest.b, src.b)
150-
return dest
151-
end
152-
function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye)
153-
map!(f, dest.a, src.a)
154-
return dest
155-
end
156-
function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye)
157-
return error("Can't write in-place.")
158-
end
159-
for f in [:+, :-]
160-
@eval begin
161-
function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
162-
if dest.a a.a b.a
163-
throw(
164-
ArgumentError(
165-
"KroneckerArray addition is only supported when the first or second arguments match.",
166-
),
167-
)
168-
end
169-
map!($f, dest.b, a.b, b.b)
170-
return dest
171-
end
172-
function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
173-
if dest.b a.b b.b
174-
throw(
175-
ArgumentError(
176-
"KroneckerArray addition is only supported when the first or second arguments match.",
177-
),
178-
)
179-
end
180-
map!($f, dest.a, a.a, b.a)
181-
return dest
182-
end
183-
function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye)
184-
return error("Can't write in-place.")
185-
end
186-
end
187-
end
188-
function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker)
189-
map!(f, dest.b, a.b)
190-
return dest
191-
end
192-
function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye)
193-
map!(f, dest.a, a.a)
194-
return dest
195-
end
196-
function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye)
197-
return error("Can't write in-place.")
198-
end
199-
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
200-
map!(f, dest.b, a.b)
201-
return dest
202-
end
203-
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
204-
map!(f, dest.a, a.a)
205-
return dest
206-
end
207-
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
208-
return error("Can't write in-place.")
209-
end
210-
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
211-
map!(f, dest.b, a.b)
212-
return dest
213-
end
214-
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
215-
map!(f, dest.a, a.a)
216-
return dest
217-
end
218-
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
219-
return error("Can't write in-place.")
220-
end
221-
22284
using Base.Broadcast:
22385
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
22486

@@ -233,3 +95,82 @@ Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2
23395
function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type)
23496
return Eye{elt}(axes(bc))
23597
end
98+
99+
function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}})
100+
dest2 = arg2(dest)
101+
f = LinearCombination(a)
102+
args = arguments(a)
103+
arg2s = arg2.(args)
104+
dest2 .= f.(arg2s...)
105+
return dest
106+
end
107+
function Base.copyto!(dest::KroneckerEye, a::Sum{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}})
108+
dest1 = arg1(dest)
109+
f = LinearCombination(a)
110+
args = arguments(a)
111+
arg1s = arg1.(args)
112+
dest1 .= f.(arg1s...)
113+
return dest
114+
end
115+
function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}})
116+
return error("Can't write in-place to `Eye ⊗ Eye`.")
117+
end
118+
119+
# Simplification rules similar to those for FillArrays.jl:
120+
# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl
121+
using FillArrays: Zeros
122+
function Base.broadcasted(
123+
style::KroneckerStyle,
124+
::typeof(+),
125+
a::KroneckerArray,
126+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
127+
)
128+
# TODO: Promote the element types.
129+
return a
130+
end
131+
function Base.broadcasted(
132+
style::KroneckerStyle,
133+
::typeof(+),
134+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
135+
b::KroneckerArray,
136+
)
137+
# TODO: Promote the element types.
138+
return b
139+
end
140+
function Base.broadcasted(
141+
style::KroneckerStyle,
142+
::typeof(+),
143+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
144+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
145+
)
146+
# TODO: Promote the element types and axes.
147+
return b
148+
end
149+
function Base.broadcasted(
150+
style::KroneckerStyle,
151+
::typeof(-),
152+
a::KroneckerArray,
153+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
154+
)
155+
# TODO: Promote the element types.
156+
return a
157+
end
158+
function Base.broadcasted(
159+
style::KroneckerStyle,
160+
::typeof(-),
161+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
162+
b::KroneckerArray,
163+
)
164+
# TODO: Promote the element types.
165+
# TODO: Return `broadcasted(-, b)`.
166+
return -b
167+
end
168+
function Base.broadcasted(
169+
style::KroneckerStyle,
170+
::typeof(-),
171+
a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
172+
b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros},
173+
)
174+
# TODO: Promote the element types and axes.
175+
return b
176+
end

0 commit comments

Comments
 (0)