diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index 29aff5271..16e4b46e5 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -34,6 +34,8 @@ import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator } +@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator() + function rrule!!( f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)}, X::CoDual{SciMLBase.ChainRulesOriginator}