Skip to content

Commit ce2cef3

Browse files
committed
get rid of branch that was only uses to change broadcasting
1 parent f7d4b3f commit ce2cef3

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

src/rule_definition_tools.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -244,24 +244,13 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
244244
esc(∂s_i)
245245
end
246246
end
247-
n∂s = length(_∂s)
248247

249-
summed_∂_mul_Δs = if n∂s > 1
250-
# Explicit multiplication is only performed for the first pair
251-
# of partial and gradient.
252-
init_expr = :(*($(_∂s[1]), $(Δs[1])))
253-
254-
# Apply `muladd` iteratively.
255-
foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i)
256-
:(muladd($∂s_i, $Δs_i, $ex))
257-
end
258-
else
259-
# Note: we don't want to do broadcasting with only 1 multiply (no `+`),
260-
# because some arrays overload multiply with scalar. Avoiding
261-
# broadcasting saves compilation time.
262-
:($(_∂s[1]) * $(Δs[1]))
248+
# Apply `muladd` iteratively.
249+
# Explicit multiplication is only performed for the first pair of partial and gradient.
250+
init_expr = :(*($(_∂s[1]), $(Δs[1])))
251+
summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i)
252+
:(muladd($∂s_i, $Δs_i, $ex))
263253
end
264-
265254
return :($proj($summed_∂_mul_Δs))
266255
end
267256

0 commit comments

Comments
 (0)