Skip to content

Commit ceed5e7

Browse files
committed
Allow fit_parameters keyword
1 parent f36d93c commit ceed5e7

File tree

5 files changed

+24
-23
lines changed

5 files changed

+24
-23
lines changed

ext/DiffEqBaseChainRulesCoreExt.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module DiffEqBaseChainRulesCoreExt
22

33
using DiffEqBase
44
using DiffEqBase.SciMLBase
5-
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem, set_mooncakeoriginator_if_mooncake
5+
import DiffEqBase: numargs, AbstractSensitivityAlgorithm, AbstractDEProblem,
6+
set_mooncakeoriginator_if_mooncake
67

78
import ChainRulesCore
89
import ChainRulesCore: NoTangent
@@ -15,7 +16,8 @@ function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
1516
u0, p, args...;
1617
kwargs...)
1718
DiffEqBase._solve_forward(
18-
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
19+
prob, sensealg, u0, p,
20+
set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
1921
kwargs...)
2022
end
2123

@@ -24,7 +26,8 @@ function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEPro
2426
u0, p, args...;
2527
kwargs...)
2628
DiffEqBase._solve_adjoint(
27-
prob, sensealg, u0, p, set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
29+
prob, sensealg, u0, p,
30+
set_mooncakeoriginator_if_mooncake(SciMLBase.ChainRulesOriginator()), args...;
2831
kwargs...)
2932
end
3033

ext/DiffEqBaseMooncakeExt.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,29 @@ module DiffEqBaseMooncakeExt
33
using DiffEqBase, Mooncake
44
using DiffEqBase: SciMLBase
55
using SciMLBase: ADOriginator, MooncakeOriginator
6-
Mooncake.@from_rrule(
7-
Mooncake.MinimalCtx,
6+
Mooncake.@from_rrule(Mooncake.MinimalCtx,
87
Tuple{
98
typeof(DiffEqBase.solve_up),
109
DiffEqBase.AbstractDEProblem,
1110
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
1211
Any,
1312
Any,
14-
Any,
13+
Any
1514
},
16-
true,
17-
)
15+
true,)
1816

1917
# Dispatch for auto-alg
20-
Mooncake.@from_rrule(
21-
Mooncake.MinimalCtx,
18+
Mooncake.@from_rrule(Mooncake.MinimalCtx,
2219
Tuple{
2320
typeof(DiffEqBase.solve_up),
2421
DiffEqBase.AbstractDEProblem,
2522
Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm},
2623
Any,
27-
Any,
24+
Any
2825
},
29-
true,
30-
)
26+
true,)
3127

3228
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
3329
Mooncake.@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::ADOriginator) = MooncakeOriginator
3430

35-
end
31+
end

src/solve.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ const allowedkeywords = (:dense,
9999
# Termination condition for solvers
100100
:termination_condition,
101101
# For AbstractAliasSpecifier
102-
:alias)
102+
:alias,
103+
# Parameter estimation with BVP
104+
:fit_parameters)
103105

104106
const KWARGWARN_MESSAGE = """
105107
Unrecognized keyword arguments found.
@@ -541,7 +543,8 @@ end
541543
Get the innermost index provider using `SII.symbolic_container`.
542544
"""
543545
function _get_root_indp(indp)
544-
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && (sc = SII.symbolic_container(indp)) !== indp
546+
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) &&
547+
(sc = SII.symbolic_container(indp)) !== indp
545548
return _get_root_indp(sc)
546549
end
547550
return indp
@@ -748,7 +751,7 @@ function build_null_solution(prob::AbstractDEProblem, args...;
748751

749752
prob, success = hack_null_solution_init(prob)
750753
retcode = success ? ReturnCode.Success : ReturnCode.InitialFailure
751-
build_solution(prob, nothing, ts, timeseries; dense=true, retcode)
754+
build_solution(prob, nothing, ts, timeseries; dense = true, retcode)
752755
end
753756

754757
function build_null_solution(
@@ -1139,7 +1142,6 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
11391142
sensealg = prob.kwargs[:sensealg]
11401143
end
11411144

1142-
11431145
if haskey(prob.kwargs, :alias_u0)
11441146
@warn "The `alias_u0` keyword argument is deprecated. Please use a NonlinearAliasSpecifier, e.g. `alias = NonlinearAliasSpecifier(alias_u0 = true)`."
11451147
alias_spec = NonlinearAliasSpecifier(alias_u0 = prob.kwargs[:alias_u0])
@@ -1152,7 +1154,7 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
11521154
alias_spec = NonlinearAliasSpecifier(alias = prob.kwargs[:alias])
11531155
elseif haskey(kwargs, :alias) && kwargs[:alias] isa Bool
11541156
alias_spec = NonlinearAliasSpecifier(alias = kwargs[:alias])
1155-
end
1157+
end
11561158

11571159
if haskey(prob.kwargs, :alias) && prob.kwargs[:alias] isa NonlinearAliasSpecifier
11581160
alias_spec = prob.kwargs[:alias]

test/downstream/unitful.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ intg = init(prob, Tsit5())
55
@test_nowarn step!(intg, 0.02u"s", true)
66

77
@test DiffEqBase.unitfulvalue(u"1/s") == u"1/s"
8-
@test DiffEqBase.value(ForwardDiff.Dual(1)*u"1/s") == 1
9-
@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1)*u"1/s") == u"1/s"
8+
@test DiffEqBase.value(ForwardDiff.Dual(1) * u"1/s") == 1
9+
@test DiffEqBase.unitfulvalue(ForwardDiff.Dual(1) * u"1/s") == u"1/s"

test/static/static_checks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DiffEqBase, ComponentArrays, AllocCheck, Test
22

3-
u = ComponentArray(x=1.0, y=0.0, z=0.0)
3+
u = ComponentArray(x = 1.0, y = 0.0, z = 0.0)
44
t = 0.0
5-
@test length(check_allocs(DiffEqBase.ODE_DEFAULT_NORM, (typeof(u), typeof(t)))) == 0
5+
@test length(check_allocs(DiffEqBase.ODE_DEFAULT_NORM, (typeof(u), typeof(t)))) == 0

0 commit comments

Comments
 (0)