Skip to content

Commit 8d5ec61

Browse files
committed
add Mooncake stuff
1 parent 902fa17 commit 8d5ec61

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
4343
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
4444
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
4545
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
46+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
4647
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
4748
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
4849
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -62,6 +63,7 @@ SciMLBaseRCallExt = "RCall"
6263
SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]
6364
SciMLBaseDistributionsExt = "Distributions"
6465
SciMLBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
66+
SciMLBaseMooncakeExt = "Mooncake"
6567
SciMLBaseReverseDiffExt = "ReverseDiff"
6668
SciMLBaseTrackerExt = "Tracker"
6769
SciMLBaseForwardDiffExt = "ForwardDiff"
@@ -88,6 +90,7 @@ MLStyle = "0.4.17"
8890
Makie = "0.20, 0.21, 0.22, 0.23, 0.24"
8991
Markdown = "1.10"
9092
Moshi = "0.3"
93+
Mooncake = "0.4"
9194
Measurements = "2"
9295
MonteCarloMeasurements = "1"
9396
PartialFunctions = "1.1"

ext/SciMLBaseMooncakeExt.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module SciMLBaseMooncakeExt
2+
3+
using SciMLBase, Mooncake
4+
using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator
5+
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
6+
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
7+
NoPullback
8+
9+
@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
10+
@is_primitive MinimalCtx Tuple{
11+
typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator
12+
}
13+
14+
@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator()
15+
16+
function rrule!!(
17+
f::CoDual{typeof(SciMLBase.set_mooncakeoriginator_if_mooncake)},
18+
X::CoDual{SciMLBase.ChainRulesOriginator}
19+
)
20+
return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X)
21+
end
22+
23+
24+
25+
end

src/solve.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,4 +484,6 @@ totallength(x::AbstractArray) = __sum(totallength, x; init = 0)
484484

485485
_reshape(v, siz) = reshape(v, siz)
486486
_reshape(v::Number, siz) = v
487-
_reshape(v::AbstractSciMLScalarOperator, siz) = v
487+
_reshape(v::AbstractSciMLScalarOperator, siz) = v
488+
489+
set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = x

0 commit comments

Comments
 (0)