Skip to content

Commit 1816ff9

Browse files
committed
wip: reduce type instabilities in MooncakeExt
1 parent 0013272 commit 1816ff9

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

ext/DynamicExpressionsMooncakeExt.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,19 @@ _field_sym(i::Int) =
271271
_field_sym(::Type{Val{F}}) where {F} = _field_sym(F)
272272
_field_sym(::Val{F}) where {F} = _field_sym(F)
273273

274+
struct Pullback{T,field_sym,n_args}
275+
pt::T
276+
end
277+
function (pb::Pullback{T,field_sym,n_args})(Δy_rdata) where {T,field_sym,n_args}
278+
if field_sym === :val && !(Δy_rdata isa Mooncake.NoRData)
279+
pb.pt.val = Mooncake.increment_rdata!!(pb.pt.val, Δy_rdata)
280+
end
281+
return ntuple(_ -> Mooncake.NoRData(), Val(n_args))
282+
end
283+
274284
function _rrule_getfield_common(
275-
obj_cd::Mooncake.CoDual{N,TangentExprNode{Tv}}, field_sym::Symbol, n_args::Int
276-
) where {T,N<:AbstractExpressionNode{T},Tv}
285+
obj_cd::Mooncake.CoDual{N,TangentExprNode{Tv}}, ::Val{field_sym}, ::Val{n_args}
286+
) where {T,N<:AbstractExpressionNode{T},Tv,field_sym,n_args}
277287
p = Mooncake.primal(obj_cd)
278288
pt = Mooncake.tangent(obj_cd)
279289

@@ -296,14 +306,7 @@ function _rrule_getfield_common(
296306
Mooncake.fdata(tangent_for_field)
297307
end
298308
y_cd = Mooncake.CoDual(value_primal, fdata_for_output)
299-
300-
function pb(Δy_rdata)
301-
if field_sym === :val && !(Δy_rdata isa Mooncake.NoRData)
302-
pt.val = Mooncake.increment_rdata!!(pt.val, Δy_rdata)
303-
end
304-
return ntuple(_ -> Mooncake.NoRData(), n_args)
305-
end
306-
return y_cd, pb
309+
return y_cd, Pullback{typeof(pt),field_sym,n_args}(pt)
307310
end
308311

309312
# lgetfield(AEN, Val{field})
@@ -313,7 +316,7 @@ function Mooncake.rrule!!(
313316
obj_cd::Mooncake.CoDual{N,TangentExprNode{Tv}},
314317
vfield_cd::Mooncake.CoDual{Val{F},Mooncake.NoFData},
315318
) where {T,N<:AbstractExpressionNode{T},Tv,F}
316-
return _rrule_getfield_common(obj_cd, _field_sym(F), 3)
319+
return _rrule_getfield_common(obj_cd, Val(_field_sym(F)), Val(3))
317320
end
318321

319322
# getfield by Symbol
@@ -323,7 +326,7 @@ function Mooncake.rrule!!(
323326
obj_cd::Mooncake.CoDual{N,TangentExprNode{Tv}},
324327
sym_cd::Mooncake.CoDual{Symbol,Mooncake.NoFData},
325328
) where {T,N<:AbstractExpressionNode{T},Tv}
326-
return _rrule_getfield_common(obj_cd, Mooncake.primal(sym_cd), 3)
329+
return _rrule_getfield_common(obj_cd, Val(Mooncake.primal(sym_cd)), Val(3))
327330
end
328331

329332
# getfield by Int

0 commit comments

Comments
 (0)