File tree Expand file tree Collapse file tree 2 files changed +19
-10
lines changed Expand file tree Collapse file tree 2 files changed +19
-10
lines changed Original file line number Diff line number Diff line change 296296"""
297297 @thunk body
298298
299- Returns `Thunk(() -> body)`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
300- In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
299+ Returns `Thunk(() -> body)`
301300"""
302301macro thunk (body)
303- if body isa Expr && body. head == :call
304- fname = body. args[1 ]
305- if fname in (:Wirtinger , :ComplexGradient )
306- return :($ fname ($ ((:(@thunk $ i) for i in body. args[2 : end ]). .. )))
307- end
308- end
309302 return :(Thunk (() -> $ (esc (body))))
310303end
311304
Original file line number Diff line number Diff line change @@ -224,16 +224,32 @@ end
224224"""
225225function frule_propagation_expr (𝒟, Δs, ∂s)
226226 ∂s = map (esc, ∂s)
227- ∂_mul_Δs = [:(chain (@thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
227+ ∂_mul_Δs = [:(chain (@_thunk ($ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
228228 return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
229229end
230230
231231function rrule_propagation_expr (𝒟, Δs, ∂s)
232232 ∂s = map (esc, ∂s)
233- ∂_mul_Δs = [:(chain ($ (Δs[i]), @thunk ($ (∂s[i])))) for i in 1 : length (∂s)]
233+ ∂_mul_Δs = [:(chain ($ (Δs[i]), @_thunk ($ (∂s[i])))) for i in 1 : length (∂s)]
234234 return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
235235end
236236
237+ """
238+ @_thunk body
239+
240+ Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
241+ In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242+ """
243+ macro _thunk (body)
244+ if body isa Expr && body. head == :call
245+ fname = body. args[1 ]
246+ if fname in (:Wirtinger , :ComplexGradient )
247+ return :($ fname ($ ((:(@thunk $ (esc (i))) for i in body. args[2 : end ]). .. )))
248+ end
249+ end
250+ return :(@thunk $ (esc (body)))
251+ end
252+
237253"""
238254 propagator_name(f, propname)
239255
You can’t perform that action at this time.
0 commit comments