@@ -56,7 +56,7 @@ derivative/setup expressions.
56
56
This macro assumes complex functions are holomorphic. In general, for non-holomorphic
57
57
functions, the `frule` and `rrule` must be defined manually.
58
58
59
- If the derivative is one, (e.g. for identity functions) `true` can be used as the most
59
+ If the derivative is one, (e.g. for identity functions) `true` can be used as the most
60
60
general multiplicative identity.
61
61
62
62
The `@setup` argument can be elided if no setup code is need. In other
@@ -244,24 +244,13 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity)
244
244
esc (∂s_i)
245
245
end
246
246
end
247
- n∂s = length (_∂s)
248
247
249
- summed_∂_mul_Δs = if n∂s > 1
250
- # Explicit multiplication is only performed for the first pair
251
- # of partial and gradient.
252
- init_expr = :((* ). ($ (_∂s[1 ]), $ (Δs[1 ])))
253
-
254
- # Apply `muladd` iteratively.
255
- foldl (Iterators. drop (zip (_∂s, Δs), 1 ); init= init_expr) do ex, (∂s_i, Δs_i)
256
- :((muladd). ($ ∂s_i, $ Δs_i, $ ex))
257
- end
258
- else
259
- # Note: we don't want to do broadcasting with only 1 multiply (no `+`),
260
- # because some arrays overload multiply with scalar. Avoiding
261
- # broadcasting saves compilation time.
262
- :($ (_∂s[1 ]) * $ (Δs[1 ]))
248
+ # Apply `muladd` iteratively.
249
+ # Explicit multiplication is only performed for the first pair of partial and gradient.
250
+ init_expr = :(* ($ (_∂s[1 ]), $ (Δs[1 ])))
251
+ summed_∂_mul_Δs = foldl (Iterators. drop (zip (_∂s, Δs), 1 ); init= init_expr) do ex, (∂s_i, Δs_i)
252
+ :(muladd ($ ∂s_i, $ Δs_i, $ ex))
263
253
end
264
-
265
254
return :($ proj ($ summed_∂_mul_Δs))
266
255
end
267
256
0 commit comments