Skip to content

Commit b451452

Browse files
Merge pull request #1158 from SciML/mooncake
Setup Mooncake rrule
2 parents 4423738 + caa8b4f commit b451452

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.174.0"
4+
version = "6.175.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -45,6 +45,7 @@ GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
4545
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4646
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
4747
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
48+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
4849
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4950
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5051
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -61,6 +62,7 @@ DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
6162
DiffEqBaseMPIExt = "MPI"
6263
DiffEqBaseMeasurementsExt = "Measurements"
6364
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
65+
DiffEqBaseMooncakeExt = "Mooncake"
6466
DiffEqBaseReverseDiffExt = "ReverseDiff"
6567
DiffEqBaseSparseArraysExt = "SparseArrays"
6668
DiffEqBaseTrackerExt = "Tracker"
@@ -91,14 +93,15 @@ MPI = "0.20"
9193
Markdown = "1.9"
9294
Measurements = "2"
9395
MonteCarloMeasurements = "1"
96+
Mooncake = "0.4"
9497
MuladdMacro = "0.2.1"
9598
Parameters = "0.12.0"
9699
PrecompileTools = "1"
97100
Printf = "1.9"
98101
RecursiveArrayTools = "3"
99102
Reexport = "1.0"
100103
ReverseDiff = "1"
101-
SciMLBase = "2.60.0"
104+
SciMLBase = "2.94.0"
102105
SciMLOperators = "0.3, 0.4, 1"
103106
SciMLStructures = "1.5"
104107
Setfield = "1"

ext/DiffEqBaseChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module DiffEqBaseChainRulesCoreExt
22

33
using DiffEqBase
44
using DiffEqBase.SciMLBase
5-
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem
5+
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem, set_mooncakeoriginator_if_mooncake
66

77
import ChainRulesCore
88
import ChainRulesCore: NoTangent
@@ -15,7 +15,7 @@ function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
1515
u0, p, args...;
1616
kwargs...)
1717
DiffEqBase._solve_forward(
18-
prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
18+
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
1919
kwargs...)
2020
end
2121

@@ -24,7 +24,7 @@ function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEPro
2424
u0, p, args...;
2525
kwargs...)
2626
DiffEqBase._solve_adjoint(
27-
prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
27+
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
2828
kwargs...)
2929
end
3030

ext/DiffEqBaseMooncakeExt.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module DiffEqBaseMooncakeExt
2+
3+
using DiffEqBase, Mooncake
4+
using DiffEqBase: SciMLBase
5+
using SciMLBase: ADOriginator, MooncakeOriginator
6+
Mooncake.@from_rrule(
7+
Mooncake.MinimalCtx,
8+
Tuple{
9+
typeof(DiffEqBase.solve_up),
10+
DiffEqBase.AbstractDEProblem,
11+
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
12+
Any,
13+
Any,
14+
Any,
15+
},
16+
true,
17+
)
18+
19+
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
20+
Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator
21+
22+
end

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ timedepentdtmin(::Any, dtmin) = abs(dtmin)
5555

5656
maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)
5757

58+
set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x
59+
5860
function default_logger(logger)
5961
Logging.min_enabled_level(logger) ProgressLogging.ProgressLevel && return nothing
6062

0 commit comments

Comments
 (0)