Skip to content

Commit ab28c42

Browse files
overlay, more calls to set_mooncakeoriginator_if_mooncake
1 parent f86c4ee commit ab28c42

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

ext/DiffEqBaseMooncakeExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
3434
typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator
3535
}
3636

37+
@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator()
38+
3739
function rrule!!(
3840
f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)},
3941
X::CoDual{SciMLBase.ChainRulesOriginator}

src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,9 +1093,9 @@ function solve(prob::AbstractDEProblem, args...; sensealg = nothing,
10931093
p = p !== nothing ? p : prob.p
10941094

10951095
if wrap isa Val{true}
1096-
wrap_sol(solve_up(prob, sensealg, u0, p, args...; kwargs...))
1096+
wrap_sol(solve_up(prob, sensealg, u0, p, args...; originator=set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...))
10971097
else
1098-
solve_up(prob, sensealg, u0, p, args...; kwargs...)
1098+
solve_up(prob, sensealg, u0, p, args...; originator=set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...)
10991099
end
11001100
end
11011101

0 commit comments

Comments
 (0)