Skip to content

Commit a728566

Browse files
authored
Merge pull request #748 from SciML/myb/default_params_states
Introduce default_ps and default_u0
2 parents 19203d6 + 1d19cb8 commit a728566

File tree

10 files changed

+216
-185
lines changed

10 files changed

+216
-185
lines changed

src/systems/abstractsystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,6 @@ function (f::AbstractSysToExpr)(O)
297297
end
298298
return build_expr(:call, Any[operation(O); f.(arguments(O))])
299299
end
300+
301+
get_default_p(sys) = sys.default_p
302+
get_default_u0(sys) = sys.default_u0

src/systems/control/controlsystem.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,33 @@ struct ControlSystem <: AbstractControlSystem
6868
systems: The internal systems
6969
"""
7070
systems::Vector{ControlSystem}
71+
"""
72+
default_u0: The default initial conditions to use when initial conditions
73+
are not supplied in `ODEProblem`.
74+
"""
75+
default_u0::Dict
76+
"""
77+
default_p: The default parameters to use when parameters are not supplied
78+
in `ODEProblem`.
79+
"""
80+
default_p::Dict
7181
end
7282

7383
function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps;
74-
pins = [],
75-
observed = [],
76-
systems = ODESystem[],
77-
name=gensym(:ControlSystem))
84+
pins = [],
85+
observed = [],
86+
systems = ODESystem[],
87+
default_u0=Dict(),
88+
default_p=Dict(),
89+
name=gensym(:ControlSystem))
7890
iv′ = value(iv)
7991
dvs′ = value.(dvs)
8092
controls′ = value.(controls)
8193
ps′ = value.(ps)
94+
default_u0 isa Dict || (default_u0 = Dict(default_u0))
95+
default_p isa Dict || (default_p = Dict(default_p))
8296
ControlSystem(value(loss), deqs, iv′, dvs′, controls′,
83-
ps′, pins, observed, name, systems)
97+
ps′, pins, observed, name, systems, default_u0, default_p)
8498
end
8599

86100
struct ControlToExpr
@@ -102,7 +116,7 @@ end
102116
(f::ControlToExpr)(x::Sym) = x.name
103117

104118
function constructRadauIIA5(T::Type = Float64)
105-
sq6 = sqrt(6)
119+
sq6 = sqrt(convert(T, 6))
106120
A = [11//45-7sq6/360 37//225-169sq6/1800 -2//225+sq6/75
107121
37//225+169sq6/1800 11//45+7sq6/360 -2//225-sq6/75
108122
4//9-sq6/36 4//9+sq6/36 1//9]
@@ -111,7 +125,7 @@ function constructRadauIIA5(T::Type = Float64)
111125
A = map(T,A)
112126
α = map(T,α)
113127
c = map(T,c)
114-
return(DiffEqBase.ImplicitRKTableau(A,c,α,5))
128+
return DiffEqBase.ImplicitRKTableau(A,c,α,5)
115129
end
116130

117131

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 36 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,31 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
228228
!linenumbers ? striplines(ex) : ex
229229
end
230230

231+
function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
232+
version = nothing, tgrad=false,
233+
jac = false,
234+
checkbounds = false, sparse = false,
235+
simplify=false,
236+
linenumbers = true, parallel=SerialForm(),
237+
eval_expression = true,
238+
kwargs...)
239+
dvs = states(sys)
240+
ps = parameters(sys)
241+
u0map′ = lower_mapnames(u0map,sys.iv)
242+
u0 = varmap_to_vars(u0map′,dvs; defaults=get_default_u0(sys))
243+
244+
if !(parammap isa DiffEqBase.NullParameters)
245+
parammap′ = lower_mapnames(parammap)
246+
p = varmap_to_vars(parammap′,ps; defaults=get_default_p(sys))
247+
else
248+
p = ps
249+
end
250+
251+
f = constructor(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
252+
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
253+
sparse=sparse,eval_expression=eval_expression,kwargs...)
254+
return f, u0, p
255+
end
231256

232257
function ODEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
233258
ODEFunctionExpr{true}(sys, args...; kwargs...)
@@ -254,29 +279,8 @@ Generates an ODEProblem from an ODESystem and allows for automatically
254279
symbolically calculating numerical enhancements.
255280
"""
256281
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
257-
parammap=DiffEqBase.NullParameters();
258-
version = nothing, tgrad=false,
259-
jac = false,
260-
checkbounds = false, sparse = false,
261-
simplify=false,
262-
linenumbers = true, parallel=SerialForm(),
263-
eval_expression = true,
264-
kwargs...) where iip
265-
dvs = states(sys)
266-
ps = parameters(sys)
267-
u0map′ = lower_mapnames(u0map,sys.iv)
268-
u0 = varmap_to_vars(u0map′,dvs)
269-
270-
if !(parammap isa DiffEqBase.NullParameters)
271-
parammap′ = lower_mapnames(parammap)
272-
p = varmap_to_vars(parammap′,ps)
273-
else
274-
p = ps
275-
end
276-
277-
f = ODEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
278-
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
279-
sparse=sparse,eval_expression=eval_expression,kwargs...)
282+
parammap=DiffEqBase.NullParameters();kwargs...) where iip
283+
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
280284
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
281285
end
282286

@@ -300,30 +304,12 @@ numerical enhancements.
300304
struct ODEProblemExpr{iip} end
301305

302306
function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
303-
parammap=DiffEqBase.NullParameters();
304-
version = nothing, tgrad=false,
305-
jac = false,
306-
checkbounds = false, sparse = false,
307-
simplify=false,
308-
linenumbers = false, parallel=SerialForm(),
309-
kwargs...) where iip
310-
311-
dvs = states(sys)
312-
ps = parameters(sys)
313-
u0map′ = lower_mapnames(u0map,sys.iv)
314-
u0 = varmap_to_vars(u0map′,dvs)
307+
parammap=DiffEqBase.NullParameters();
308+
kwargs...) where iip
315309

316-
if !(parammap isa DiffEqBase.NullParameters)
317-
parammap′ = lower_mapnames(parammap)
318-
p = varmap_to_vars(parammap′,ps)
319-
else
320-
p = ps
321-
end
310+
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
311+
linenumbers = get(kwargs, :linenumbers, true)
322312

323-
f = ODEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
324-
linenumbers=linenumbers,parallel=parallel,
325-
simplify=simplify,
326-
sparse=sparse,kwargs...)
327313
ex = quote
328314
f = $f
329315
u0 = $u0
@@ -358,19 +344,9 @@ Generates an SteadyStateProblem from an ODESystem and allows for automatically
358344
symbolically calculating numerical enhancements.
359345
"""
360346
function DiffEqBase.SteadyStateProblem{iip}(sys::AbstractODESystem,u0map,
361-
parammap=DiffEqBase.NullParameters();
362-
version = nothing, tgrad=false,
363-
jac = false,
364-
checkbounds = false, sparse = false,
365-
linenumbers = true, parallel=SerialForm(),
366-
kwargs...) where iip
367-
dvs = states(sys)
368-
ps = parameters(sys)
369-
u0 = varmap_to_vars(u0map,dvs)
370-
p = varmap_to_vars(parammap,ps)
371-
f = ODEFunction(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
372-
linenumbers=linenumbers,parallel=parallel,
373-
sparse=sparse,kwargs...)
347+
parammap=DiffEqBase.NullParameters();
348+
kwargs...) where iip
349+
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
374350
SteadyStateProblem(f,u0,p;kwargs...)
375351
end
376352

@@ -393,18 +369,9 @@ struct SteadyStateProblemExpr{iip} end
393369

394370
function SteadyStateProblemExpr{iip}(sys::AbstractODESystem,u0map,
395371
parammap=DiffEqBase.NullParameters();
396-
version = nothing, tgrad=false,
397-
jac = false,
398-
checkbounds = false, sparse = false,
399-
linenumbers = true, parallel=SerialForm(),
400372
kwargs...) where iip
401-
dvs = states(sys)
402-
ps = parameters(sys)
403-
u0 = varmap_to_vars(u0map,dvs)
404-
p = varmap_to_vars(parammap,ps)
405-
f = ODEFunctionExpr(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
406-
linenumbers=linenumbers,parallel=parallel,
407-
sparse=sparse,kwargs...)
373+
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
374+
linenumbers = get(kwargs, :linenumbers, true)
408375
ex = quote
409376
f = $f
410377
u0 = $u0

src/systems/diffeqs/odesystem.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,39 @@ struct ODESystem <: AbstractODESystem
6161
systems: The internal systems
6262
"""
6363
systems::Vector{ODESystem}
64+
"""
65+
default_u0: The default initial conditions to use when initial conditions
66+
are not supplied in `ODEProblem`.
67+
"""
68+
default_u0::Dict
69+
"""
70+
default_p: The default parameters to use when parameters are not supplied
71+
in `ODEProblem`.
72+
"""
73+
default_p::Dict
6474
end
6575

66-
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
76+
function ODESystem(
77+
deqs::AbstractVector{<:Equation}, iv, dvs, ps;
6778
pins = Num[],
6879
observed = Num[],
6980
systems = ODESystem[],
70-
name=gensym(:ODESystem))
81+
name=gensym(:ODESystem),
82+
default_u0=Dict(),
83+
default_p=Dict(),
84+
)
7185
iv′ = value(iv)
7286
dvs′ = value.(dvs)
7387
ps′ = value.(ps)
88+
89+
default_u0 isa Dict || (default_u0 = Dict(default_u0))
90+
default_p isa Dict || (default_p = Dict(default_p))
91+
7492
tgrad = RefValue(Vector{Num}(undef, 0))
7593
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
7694
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
7795
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
78-
ODESystem(deqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems)
96+
ODESystem(deqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p)
7997
end
8098

8199
var_from_nested_derivative(x, i=0) = (missing, missing)

src/systems/diffeqs/sdesystem.jl

Lines changed: 26 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,37 @@ struct SDESystem <: AbstractODESystem
6767
Systems: the internal systems
6868
"""
6969
systems::Vector{SDESystem}
70+
"""
71+
default_u0: The default initial conditions to use when initial conditions
72+
are not supplied in `ODEProblem`.
73+
"""
74+
default_u0::Dict
75+
"""
76+
default_p: The default parameters to use when parameters are not supplied
77+
in `ODEProblem`.
78+
"""
79+
default_p::Dict
7080
end
7181

7282
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
7383
pins = [],
7484
observed = [],
7585
systems = SDESystem[],
86+
default_u0=Dict(),
87+
default_p=Dict(),
7688
name = gensym(:SDESystem))
7789
iv′ = value(iv)
7890
dvs′ = value.(dvs)
7991
ps′ = value.(ps)
92+
93+
default_u0 isa Dict || (default_u0 = Dict(default_u0))
94+
default_p isa Dict || (default_p = Dict(default_p))
95+
8096
tgrad = RefValue(Vector{Num}(undef, 0))
8197
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
8298
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
8399
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
84-
SDESystem(deqs, neqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems)
100+
SDESystem(deqs, neqs, iv′, dvs′, ps′, pins, observed, tgrad, jac, Wfact, Wfact_t, name, systems, default_u0, default_p)
85101
end
86102

87103
function generate_diffusion_function(sys::SDESystem, dvs = sys.states, ps = sys.ps; kwargs...)
@@ -299,31 +315,11 @@ Generates an SDEProblem from an SDESystem and allows for automatically
299315
symbolically calculating numerical enhancements.
300316
"""
301317
function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBase.NullParameters();
302-
version = nothing, tgrad=false,
303-
jac = false, Wfact = false,
304-
checkbounds = false, sparse = false,
305-
sparsenoise = sparse,
306-
linenumbers = true, parallel=SerialForm(),
307-
eval_expression = true,
318+
sparsenoise = nothing,
308319
kwargs...) where iip
320+
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; kwargs...)
321+
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
309322

310-
dvs = states(sys)
311-
ps = parameters(sys)
312-
313-
u0map′ = lower_mapnames(u0map,sys.iv)
314-
u0 = varmap_to_vars(u0map′,dvs)
315-
316-
if !(parammap isa DiffEqBase.NullParameters)
317-
parammap′ = lower_mapnames(parammap)
318-
p = varmap_to_vars(parammap′,ps)
319-
else
320-
p = ps
321-
end
322-
323-
f = SDEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,
324-
checkbounds=checkbounds,
325-
linenumbers=linenumbers,parallel=parallel,
326-
sparse=sparse, eval_expression=eval_expression,kwargs...)
327323
if typeof(sys.noiseeqs) <: AbstractVector
328324
noise_rate_prototype = nothing
329325
elseif sparsenoise
@@ -358,29 +354,13 @@ numerical enhancements.
358354
struct SDEProblemExpr{iip} end
359355

360356
function SDEProblemExpr{iip}(sys::SDESystem,u0map,tspan,
361-
parammap=DiffEqBase.NullParameters();
362-
version = nothing, tgrad=false,
363-
jac = false, Wfact = false,
364-
checkbounds = false, sparse = false,
365-
linenumbers = false, parallel=SerialForm(),
366-
kwargs...) where iip
367-
dvs = states(sys)
368-
ps = parameters(sys)
357+
parammap=DiffEqBase.NullParameters();
358+
sparsenoise = nothing,
359+
kwargs...) where iip
360+
f, u0, p = process_DEProblem(SDEFunctionExpr{iip}, sys, u0map, parammap; kwargs...)
361+
linenumbers = get(kwargs, :linenumbers, true)
362+
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
369363

370-
u0map′ = lower_mapnames(u0map,sys.iv)
371-
u0 = varmap_to_vars(u0map′,dvs)
372-
373-
if !(parammap isa DiffEqBase.NullParameters)
374-
parammap′ = lower_mapnames(parammap)
375-
p = varmap_to_vars(parammap′,ps)
376-
else
377-
p = ps
378-
end
379-
380-
f = SDEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,
381-
Wfact=Wfact,checkbounds=checkbounds,
382-
linenumbers=linenumbers,parallel=parallel,
383-
sparse=sparse,kwargs...)
384364
if typeof(sys.noiseeqs) <: AbstractVector
385365
noise_rate_prototype = nothing
386366
elseif sparsenoise

0 commit comments

Comments
 (0)