Skip to content

Commit 1d3aa33

Browse files
author
Wimmerer
committed
in-place assignment for broadcasts, need tests
1 parent 116496b commit 1d3aa33

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

src/operations/broadcasts.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#Broadcasting machinery
22
#######################
33

4+
# YOU SHALL NOT PASS:
5+
# For real though, some of this is far from pretty.
6+
47
valunwrap(::Val{x}) where x = x
58
#This is directly from the Broadcasting interface docs
69
struct GBVectorStyle <: Broadcast.AbstractArrayStyle{1} end
@@ -69,6 +72,81 @@ modifying(::typeof(emul)) = emul!
6972
end
7073
end
7174
end
75+
mutatingop(::typeof(emul)) = emul!
76+
mutatingop(::typeof(eadd)) = eadd!
77+
mutatingop(::typeof(map)) = map!
78+
@inline function Base.copyto!(C::GBArray, bc::Broadcast.Broadcasted{GBMatrixStyle})
79+
l = length(bc.args)
80+
if l == 1
81+
x = first(bc.args)
82+
if bc.f === Base.identity
83+
C[:,:, accum=second] = x
84+
return C
85+
end
86+
return map!(bc.f, C, x; accum=second)
87+
else
88+
89+
left = first(bc.args)
90+
right = last(bc.args)
91+
# handle annoyances with the pow operator
92+
if left isa Base.RefValue{typeof(^)}
93+
f = ^
94+
left = bc.args[2]
95+
right = valunwrap(right[])
96+
end
97+
# TODO: This if statement should probably be *inside* one of the inner ones to avoid duplication.
98+
if left === C
99+
if !(right isa Broadcast.Broadcasted)
100+
# This should be something of the form A .<op>= <expr> or A .= A .<op> <expr> which are equivalent.
101+
# this will be done by a subassign
102+
C[:,:, accum=bc.f] = right
103+
return C
104+
else
105+
# The form A .<op>= expr
106+
# but not of the form A .= C ... B.
107+
accum = bc.f
108+
f = right.f
109+
if length(right.args) == 1
110+
# Should be catching expressions of the form A .<op>= <op>.(B)
111+
subarg = first(right.args)
112+
if subarg isa Broadcast.Broadcasted
113+
subarg = copy(subarg)
114+
end
115+
return map!(f, C, subarg; accum)
116+
else
117+
# Otherwise we know there's two operands on the LHS so we have A .<op>= C .<op> B
118+
# Or a generalization with any compound *lazy* RHS.
119+
(subargleft, subargright) = right.args
120+
# subargleft and subargright are C and B respectively.
121+
# If they're further nested broadcasts we can't fuse them, so just copy.
122+
subargleft isa Broadcast.Broadcasted && (subargleft = copy(subargleft))
123+
subargright isa Broadcast.Broadcasted && (subargright = copy(subargright))
124+
if subargleft isa GBArray && subargright isa GBArray
125+
add = mutatingop(defaultadd(f))
126+
return add(C, subargleft, subargright, f; accum)
127+
else
128+
return map!(f, C, subargleft, subargright; accum)
129+
end
130+
end
131+
end
132+
else
133+
# Some expression of the form A .= C .<op> B or a generalization
134+
# excluding A .= A .<op> <expr>, since that is captured above.
135+
if left isa Broadcast.Broadcasted
136+
left = copy(left)
137+
end
138+
if right isa Broadcast.Broadcasted
139+
right = copy(right)
140+
end
141+
if left isa GBArray && right isa GBArray
142+
add = mutatingop(defaultadd(f))
143+
return add(C, left, right, f)
144+
else
145+
return map!(C, f, left, right; accum=second)
146+
end
147+
end
148+
end
149+
end
72150

73151
@inline function Base.copy(bc::Broadcast.Broadcasted{GBVectorStyle})
74152
f = bc.f

0 commit comments

Comments
 (0)