From ab28c42adfd02a1789875520169e23fcfad87122 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Wed, 9 Jul 2025 16:43:42 +0530 Subject: [PATCH] overlay, more calls to set_mooncakeoriginator_if_mooncake --- ext/DiffEqBaseMooncakeExt.jl | 2 ++ src/solve.jl | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/DiffEqBaseMooncakeExt.jl b/ext/DiffEqBaseMooncakeExt.jl index 29aff5271..16e4b46e5 100644 --- a/ext/DiffEqBaseMooncakeExt.jl +++ b/ext/DiffEqBaseMooncakeExt.jl @@ -34,6 +34,8 @@ import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive, typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator } +@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator() + function rrule!!( f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)}, X::CoDual{SciMLBase.ChainRulesOriginator} diff --git a/src/solve.jl b/src/solve.jl index 633e0246d..f20719867 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1093,9 +1093,9 @@ function solve(prob::AbstractDEProblem, args...; sensealg = nothing, p = p !== nothing ? p : prob.p if wrap isa Val{true} - wrap_sol(solve_up(prob, sensealg, u0, p, args...; kwargs...)) + wrap_sol(solve_up(prob, sensealg, u0, p, args...; originator=set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...)) else - solve_up(prob, sensealg, u0, p, args...; kwargs...) + solve_up(prob, sensealg, u0, p, args...; originator=set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...) end end