File tree Expand file tree Collapse file tree 1 file changed +5
-9
lines changed Expand file tree Collapse file tree 1 file changed +5
-9
lines changed Original file line number Diff line number Diff line change @@ -224,26 +224,22 @@ end
224224"""
225225function frule_propagation_expr (𝒟, Δs, ∂s)
226226 ∂s = map (esc, ∂s)
227- ∂_mul_Δs = [:(chain (@_thunk ( $ (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
227+ ∂_mul_Δs = [:(chain ($ ( _thunk (∂s[i])), $ (Δs[i]))) for i in 1 : length (∂s)]
228228 return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
229229end
230230
231231function rrule_propagation_expr (𝒟, Δs, ∂s)
232232 ∂s = map (esc, ∂s)
233- ∂_mul_Δs = [:(chain ($ (Δs[i]), @_thunk ( $ (∂s[i])))) for i in 1 : length (∂s)]
233+ ∂_mul_Δs = [:(chain ($ (Δs[i]), $ ( _thunk (∂s[i])))) for i in 1 : length (∂s)]
234234 return :(refine_differential ($ 𝒟, + ($ (∂_mul_Δs... ))))
235235end
236236
237237"""
238- @ _thunk body
238+ _thunk( body)
239239
240240Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
241241In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
242242"""
243- macro _thunk (body)
244- return _thunk (body)
245- end
246-
247243function _thunk (body)
248244 if body isa Expr
249245 if body. head == :call
261257thunk_assert_no_wirtinger (body) = quote
262258 Thunk (
263259 function ()
264- res = $ ( esc ( body))
265- res isa AbstractWirtinger && error ("""
260+ res = $ body
261+ res isa ChainRulesCore . AbstractWirtinger && error ("""
266262 Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule.
267263 Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""" )
268264 return res
You can’t perform that action at this time.
0 commit comments