@@ -271,9 +271,19 @@ _field_sym(i::Int) =
271271_field_sym (:: Type{Val{F}} ) where {F} = _field_sym (F)
272272_field_sym (:: Val{F} ) where {F} = _field_sym (F)
273273
274+ struct Pullback{T,field_sym,n_args}
275+ pt:: T
276+ end
277+ function (pb:: Pullback{T,field_sym,n_args} )(Δy_rdata) where {T,field_sym,n_args}
278+ if field_sym === :val && ! (Δy_rdata isa Mooncake. NoRData)
279+ pb. pt. val = Mooncake. increment_rdata!! (pb. pt. val, Δy_rdata)
280+ end
281+ return ntuple (_ -> Mooncake. NoRData (), Val (n_args))
282+ end
283+
274284function _rrule_getfield_common (
275- obj_cd:: Mooncake.CoDual{N,TangentExprNode{Tv}} , field_sym :: Symbol , n_args :: Int
276- ) where {T,N<: AbstractExpressionNode{T} ,Tv}
285+ obj_cd:: Mooncake.CoDual{N,TangentExprNode{Tv}} , :: Val{field_sym} , :: Val{n_args}
286+ ) where {T,N<: AbstractExpressionNode{T} ,Tv,field_sym,n_args }
277287 p = Mooncake. primal (obj_cd)
278288 pt = Mooncake. tangent (obj_cd)
279289
@@ -296,14 +306,7 @@ function _rrule_getfield_common(
296306 Mooncake. fdata (tangent_for_field)
297307 end
298308 y_cd = Mooncake. CoDual (value_primal, fdata_for_output)
299-
300- function pb (Δy_rdata)
301- if field_sym === :val && ! (Δy_rdata isa Mooncake. NoRData)
302- pt. val = Mooncake. increment_rdata!! (pt. val, Δy_rdata)
303- end
304- return ntuple (_ -> Mooncake. NoRData (), n_args)
305- end
306- return y_cd, pb
309+ return y_cd, Pullback {typeof(pt),field_sym,n_args} (pt)
307310end
308311
309312# lgetfield(AEN, Val{field})
@@ -313,7 +316,7 @@ function Mooncake.rrule!!(
313316 obj_cd:: Mooncake.CoDual{N,TangentExprNode{Tv}} ,
314317 vfield_cd:: Mooncake.CoDual{Val{F},Mooncake.NoFData} ,
315318) where {T,N<: AbstractExpressionNode{T} ,Tv,F}
316- return _rrule_getfield_common (obj_cd, _field_sym (F), 3 )
319+ return _rrule_getfield_common (obj_cd, Val ( _field_sym (F)), Val ( 3 ) )
317320end
318321
319322# getfield by Symbol
@@ -323,7 +326,7 @@ function Mooncake.rrule!!(
323326 obj_cd:: Mooncake.CoDual{N,TangentExprNode{Tv}} ,
324327 sym_cd:: Mooncake.CoDual{Symbol,Mooncake.NoFData} ,
325328) where {T,N<: AbstractExpressionNode{T} ,Tv}
326- return _rrule_getfield_common (obj_cd, Mooncake. primal (sym_cd), 3 )
329+ return _rrule_getfield_common (obj_cd, Val ( Mooncake. primal (sym_cd)), Val ( 3 ) )
327330end
328331
329332# getfield by Int
0 commit comments