Skip to content

Commit a0fc73d

Browse files
Try fixing Mooncake originator
This doesn't trigger either, so it seems to be a bug in Mooncake chalk-lab/Mooncake.jl#587
1 parent b451452 commit a0fc73d

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

ext/DiffEqBaseChainRulesCoreExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@ ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle)
1212

1313
function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
1414
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
15-
u0, p, args...;
15+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
1616
kwargs...)
1717
DiffEqBase._solve_forward(
18-
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
18+
prob, sensealg, u0, p, originator, args...;
1919
kwargs...)
2020
end
2121

2222
function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem,
2323
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
24-
u0, p, args...;
24+
u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
2525
kwargs...)
2626
DiffEqBase._solve_adjoint(
27-
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
27+
prob, sensealg, u0, p, originator, args...;
2828
kwargs...)
2929
end
3030

ext/DiffEqBaseMooncakeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module DiffEqBaseMooncakeExt
22

33
using DiffEqBase, Mooncake
44
using DiffEqBase: SciMLBase
5-
using SciMLBase: ADOriginator, MooncakeOriginator
5+
using SciMLBase: ADOriginator, MooncakeOriginator, ChainRulesOriginator
66
Mooncake.@from_rrule(
77
Mooncake.MinimalCtx,
88
Tuple{
@@ -17,6 +17,6 @@ Mooncake.@from_rrule(
1717
)
1818

1919
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
20-
Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator
20+
Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ChainRulesOriginator) = MooncakeOriginator()
2121

2222
end

src/solve.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,14 +1168,15 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
11681168
p = p !== nothing ? p : prob.p
11691169

11701170
if wrap isa Val{true}
1171-
wrap_sol(solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...))
1171+
wrap_sol(solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...))
11721172
else
1173-
solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, kwargs...)
1173+
solve_up(prob, sensealg, u0, p, args...; alias_u0 = alias_u0, originator = set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), kwargs...)
11741174
end
11751175
end
11761176

11771177
function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
1178-
args...; kwargs...)
1178+
args...; originator = SciMLBase.ChainRulesOriginator(),
1179+
kwargs...)
11791180
alg = extract_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
11801181
if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling
11811182
_prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0,

0 commit comments

Comments
 (0)