Skip to content

Commit 8272f85

Browse files
authored
Merge pull request #388 from Keno/kf/diffractorwip2
Get rid of useless broadcasting
2 parents bc830c6 + ce2cef3 commit 8272f85

File tree

1 file changed

+6
-17
lines changed

1 file changed

+6
-17
lines changed

src/rule_definition_tools.jl

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ derivative/setup expressions.
5656
This macro assumes complex functions are holomorphic. In general, for non-holomorphic
5757
functions, the `frule` and `rrule` must be defined manually.
5858
59-
If the derivative is one, (e.g. for identity functions) `true` can be used as the most
59+
If the derivative is one, (e.g. for identity functions) `true` can be used as the most
6060
general multiplicative identity.
6161
6262
The `@setup` argument can be elided if no setup code is need. In other
@@ -244,24 +244,13 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
244244
esc(∂s_i)
245245
end
246246
end
247-
n∂s = length(_∂s)
248247

249-
summed_∂_mul_Δs = if n∂s > 1
250-
# Explicit multiplication is only performed for the first pair
251-
# of partial and gradient.
252-
init_expr = :((*).($(_∂s[1]), $(Δs[1])))
253-
254-
# Apply `muladd` iteratively.
255-
foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i)
256-
:((muladd).($∂s_i, $Δs_i, $ex))
257-
end
258-
else
259-
# Note: we don't want to do broadcasting with only 1 multiply (no `+`),
260-
# because some arrays overload multiply with scalar. Avoiding
261-
# broadcasting saves compilation time.
262-
:($(_∂s[1]) * $(Δs[1]))
248+
# Apply `muladd` iteratively.
249+
# Explicit multiplication is only performed for the first pair of partial and gradient.
250+
init_expr = :(*($(_∂s[1]), $(Δs[1])))
251+
summed_∂_mul_Δs = foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i)
252+
:(muladd($∂s_i, $Δs_i, $ex))
263253
end
264-
265254
return :($proj($summed_∂_mul_Δs))
266255
end
267256

0 commit comments

Comments
 (0)