Skip to content

Commit 63ced69

Browse files
committed
Using DiffEqBase in Mooncake extension
1 parent a34686f commit 63ced69

15 files changed

+95
-79
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3636
[weakdeps]
3737
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3838
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
39+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
3940
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4041
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4142
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -58,7 +59,7 @@ SciMLBaseForwardDiffExt = "ForwardDiff"
5859
SciMLBaseMLStyleExt = "MLStyle"
5960
SciMLBaseMakieExt = "Makie"
6061
SciMLBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
61-
SciMLBaseMooncakeExt = "Mooncake"
62+
SciMLBaseMooncakeExt = ["Mooncake", "DiffEqBase"]
6263
SciMLBasePartialFunctionsExt = "PartialFunctions"
6364
SciMLBasePyCallExt = "PyCall"
6465
SciMLBasePythonCallExt = "PythonCall"

ext/SciMLBaseDistributionsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ using Distributions, SciMLBase
55
SciMLBase.handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0)
66
SciMLBase.isdistribution(_u0::Distributions.Sampleable) = true
77

8-
end
8+
end

ext/SciMLBaseForwardDiffExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ module SciMLBaseForwardDiffExt
33
using SciMLBase, ForwardDiff
44
using ArrayInterface
55

6-
import SciMLBase:
7-
wrapfun_oop, wrapfun_iip, isdualtype, value, DualEltypeChecker,
8-
AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem,
9-
ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem,
10-
RecursiveArrayTools, totallength, sse, anyeltypedual, reduce_tup
6+
import SciMLBase:
7+
wrapfun_oop, wrapfun_iip, isdualtype, value, DualEltypeChecker,
8+
AbstractTimeseriesSolution, NonlinearProblem,
9+
NonlinearLeastSquaresProblem,
10+
ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem,
11+
RecursiveArrayTools, totallength, sse, anyeltypedual, reduce_tup
1112

1213
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
1314
isdualtype(::Type{<:ForwardDiff.Dual}) = true

ext/SciMLBaseMonteCarloMeasurementsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ function SciMLBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles
3232
end
3333
SciMLBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
3434

35-
end
35+
end

ext/SciMLBaseMooncakeExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module SciMLBaseMooncakeExt
22

33
using SciMLBase, Mooncake
44
using SciMLBase: ADOriginator, ChainRulesOriginator, MooncakeOriginator
5+
using DiffEqBase: DiffEqBase
56
import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
67
@from_rrule, @zero_adjoint, @mooncake_overlay, MinimalCtx,
78
NoPullback
@@ -20,6 +21,4 @@ function rrule!!(
2021
return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X)
2122
end
2223

23-
24-
25-
end
24+
end

ext/SciMLBaseReverseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,4 @@ function SciMLBase.promote_u0(
5454
end
5555
SciMLBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
5656

57-
end
57+
end

ext/SciMLBaseTrackerExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,4 @@ SciMLBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p
3333

3434
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))
3535

36-
37-
end
36+
end

src/clock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ end
130130
$(TYPEDEF)
131131
132132
A struct representing the operation of indexing a clock to obtain a subset of the time
133-
points at which it ticked. The actual list of time points depends on the tick instances
133+
points at which it ticked. The actual list of time points depends on the tick instances
134134
on which the clock was ticking, and can be obtained via `canonicalize_indexed_clock`
135135
by providing a timeseries solution object.
136136

src/debug.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ expression. Two common reasons for this issue are:
5858

5959
function __init__()
6060
Base.Experimental.register_error_hint(DomainError) do io, e
61-
if e isa DomainError && occursin("will only return a complex result if called with a complex argument. Try ", e.msg)
61+
if e isa DomainError &&
62+
occursin("will only return a complex result if called with a complex argument. Try ", e.msg)
6263
println(io, DOMAINERROR_COMPLEX_MSG)
6364
end
6465
end

src/errors.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ const allowedkeywords = (:dense,
9090
# Parameter estimation with BVP
9191
:fit_parameters)
9292

93-
9493
const KWARGWARN_MESSAGE = """
9594
Unrecognized keyword arguments found.
9695
The only allowed keyword arguments to `solve` are:
@@ -469,4 +468,4 @@ struct LateBindingTstopsNotSupportedError <: Exception end
469468
function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError)
470469
println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE)
471470
println(io, TruncatedStacktraces.VERBOSE_MSG)
472-
end
471+
end

0 commit comments

Comments
 (0)