Skip to content

Commit 89c03ea

Browse files
authored
Merge pull request #905 from SciML/myb/defaults
Use `defaults` and variable mapping clean up
2 parents be03fe7 + 6669043 commit 89c03ea

24 files changed

+225
-335
lines changed

docs/src/basics/AbstractSystem.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ same keyword arguments, which are:
2222

2323
- `system`: This is used for specifying subsystems for hierarchical modeling with
2424
reusable components. For more information, see the [components page](@ref components)
25-
- Defaults: Keyword arguments like `default_u0` are used for specifying default
25+
- Defaults: Keyword arguments like `defaults` are used for specifying default
2626
values which are used. If a value is not given at the `SciMLProblem` construction
2727
time, its numerical value will be the default.
2828

@@ -49,9 +49,7 @@ Optionally, a system could have:
4949

5050
- `observed(sys)`: All observed equations of the system and its subsystems.
5151
- `get_observed(sys)`: Observed equations of the current-level system.
52-
- `get_default_u0(sys)`: A `Dict` that maps states into their default initial
53-
condition.
54-
- `get_default_p(sys)`: A `Dict` that maps parameters into their default value.
52+
- `get_defaults(sys)`: A `Dict` that maps variables into their default values.
5553
- `independent_variable(sys)`: The independent variable of a system.
5654
- `get_noiseeqs(sys)`: Noise equations of the current-level system.
5755

@@ -132,7 +130,7 @@ u0 = [
132130
## Default Value Handling
133131

134132
The `AbstractSystem` types allow for specifying default values, for example
135-
`default_p` inside of them. At problem construction time, these values are merged
133+
`defaults` inside of them. At problem construction time, these values are merged
136134
into the value maps, where for any repeats the value maps override the default.
137135
In addition, defaults of a higher level in the system override the defaults of
138-
a lower level in the system.
136+
a lower level in the system.

docs/src/tutorials/acausal_components.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using ModelingToolkit, Plots, DifferentialEquations
1919
# Basic electric components
2020
function Pin(;name)
2121
@variables v(t) i(t)
22-
ODESystem(Equation[], t, [v, i], [], name=name, default_u0=[v=>1.0, i=>1.0])
22+
ODESystem(Equation[], t, [v, i], [], name=name, defaults=[v=>1.0, i=>1.0])
2323
end
2424

2525
function Ground(;name)
@@ -39,7 +39,7 @@ function Resistor(;name, R = 1.0)
3939
0 ~ p.i + n.i
4040
v ~ p.i * R
4141
]
42-
ODESystem(eqs, t, [v], [R], systems=[p, n], default_p=Dict(R => val), name=name)
42+
ODESystem(eqs, t, [v], [R], systems=[p, n], defaults=Dict(R => val), name=name)
4343
end
4444

4545
function Capacitor(; name, C = 1.0)
@@ -54,7 +54,7 @@ function Capacitor(; name, C = 1.0)
5454
0 ~ p.i + n.i
5555
D(v) ~ p.i / C
5656
]
57-
ODESystem(eqs, t, [v], [C], systems=[p, n], default_p=Dict(C => val), name=name)
57+
ODESystem(eqs, t, [v], [C], systems=[p, n], defaults=Dict(C => val), name=name)
5858
end
5959

6060
function ConstantVoltage(;name, V = 1.0)
@@ -66,7 +66,7 @@ function ConstantVoltage(;name, V = 1.0)
6666
V ~ p.v - n.v
6767
0 ~ p.i + n.i
6868
]
69-
ODESystem(eqs, t, [], [V], systems=[p, n], default_p=Dict(V => val), name=name)
69+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val), name=name)
7070
end
7171

7272
R = 1.0
@@ -122,7 +122,7 @@ component to simply be the values there:
122122
```julia
123123
function Pin(;name)
124124
@variables v(t) i(t)
125-
ODESystem(Equation[], t, [v, i], [], name=name, default_u0=[v=>1.0, i=>1.0])
125+
ODESystem(Equation[], t, [v, i], [], name=name, defaults=[v=>1.0, i=>1.0])
126126
end
127127
```
128128

@@ -176,11 +176,11 @@ function Resistor(;name, R = 1.0)
176176
0 ~ p.i + n.i
177177
v ~ p.i * R
178178
]
179-
ODESystem(eqs, t, [v], [R], systems=[p, n], default_p=Dict(R => val), name=name)
179+
ODESystem(eqs, t, [v], [R], systems=[p, n], defaults=Dict(R => val), name=name)
180180
end
181181
```
182182

183-
Notice that we have created this system with a `default_p` for the resistor's
183+
Notice that we have created this system with a `defaults` for the resistor's
184184
resistance. By doing so, if the resistance of this resistor is not overridden
185185
by a higher level default or overridden at `ODEProblem` construction time, this
186186
will be the value of the resistance.
@@ -200,7 +200,7 @@ function Capacitor(; name, C = 1.0)
200200
0 ~ p.i + n.i
201201
D(v) ~ p.i / C
202202
]
203-
ODESystem(eqs, t, [v], [C], systems=[p, n], default_p=Dict(C => val), name=name)
203+
ODESystem(eqs, t, [v], [C], systems=[p, n], defaults=Dict(C => val), name=name)
204204
end
205205
```
206206

@@ -219,7 +219,7 @@ function ConstantVoltage(;name, V = 1.0)
219219
V ~ p.v - n.v
220220
0 ~ p.i + n.i
221221
]
222-
ODESystem(eqs, t, [], [V], systems=[p, n], default_p=Dict(V => val), name=name)
222+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val), name=name)
223223
end
224224
```
225225

docs/src/tutorials/tearing_parallelism.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ const t = Sym{ModelingToolkit.Parameter{Real}}(:t)
4141
const D = Differential(t)
4242
function Pin(;name)
4343
@variables v(t) i(t)
44-
ODESystem(Equation[], t, [v, i], [], name=name, default_u0=Dict([v=>1.0, i=>1.0]))
44+
ODESystem(Equation[], t, [v, i], [], name=name, defaults=Dict([v=>1.0, i=>1.0]))
4545
end
4646

4747
function Ground(;name)
@@ -59,12 +59,12 @@ function ConstantVoltage(;name, V = 1.0)
5959
V ~ p.v - n.v
6060
0 ~ p.i + n.i
6161
]
62-
ODESystem(eqs, t, [], [V], systems=[p, n], default_p=Dict(V => val), name=name)
62+
ODESystem(eqs, t, [], [V], systems=[p, n], defaults=Dict(V => val), name=name)
6363
end
6464

6565
function HeatPort(;name)
6666
@variables T(t) Q_flow(t)
67-
return ODESystem(Equation[], t, [T, Q_flow], [], default_u0=Dict(T=>293.15, Q_flow=>0.0), name=name)
67+
return ODESystem(Equation[], t, [T, Q_flow], [], defaults=Dict(T=>293.15, Q_flow=>0.0), name=name)
6868
end
6969

7070
function HeatingResistor(;name, R=1.0, TAmbient=293.15, alpha=1.0)
@@ -83,8 +83,10 @@ function HeatingResistor(;name, R=1.0, TAmbient=293.15, alpha=1.0)
8383
]
8484
ODESystem(
8585
eqs, t, [v, RTherm], [R, TAmbient, alpha], systems=[p, n, h],
86-
default_p=Dict(R=>R_val, TAmbient=>TAmbient_val, alpha=>alpha_val),
87-
default_u0=Dict(v=>0.0, RTherm=>R_val),
86+
defaults=Dict(
87+
R=>R_val, TAmbient=>TAmbient_val, alpha=>alpha_val,
88+
v=>0.0, RTherm=>R_val
89+
),
8890
name=name,
8991
)
9092
end
@@ -99,7 +101,7 @@ function HeatCapacitor(;name, rho=8050, V=1, cp=460, TAmbient=293.15)
99101
]
100102
ODESystem(
101103
eqs, t, [], [rho, V, cp], systems=[h],
102-
default_p=Dict(rho=>rho_val, V=>V_val, cp=>cp_val),
104+
defaults=Dict(rho=>rho_val, V=>V_val, cp=>cp_val),
103105
name=name,
104106
)
105107
end
@@ -117,8 +119,7 @@ function Capacitor(;name, C = 1.0)
117119
]
118120
ODESystem(
119121
eqs, t, [v], [C], systems=[p, n],
120-
default_u0=Dict(v => 0.0),
121-
default_p=Dict(C => val),
122+
defaults=Dict(v => 0.0, C => val),
122123
name=name
123124
)
124125
end
@@ -159,7 +160,7 @@ end
159160
eqs = [
160161
D(E) ~ sum(((i, sys),)->getproperty(sys, Symbol(:resistor, i)).h.Q_flow, enumerate(rc_systems))
161162
]
162-
big_rc = ODESystem(eqs, t, [], [], systems=rc_systems, default_u0=Dict(E=>0.0))
163+
big_rc = ODESystem(eqs, t, [], [], systems=rc_systems, defaults=Dict(E=>0.0))
163164
```
164165

165166
Now let's say we want to expose a bit more parallelism via running tearing.

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using ModelingToolkit: ODESystem, var_from_nested_derivative, Differential,
1616
states, equations, vars, Symbolic, diff2term, value,
1717
operation, arguments, Sym, Term, simplify, solve_for,
1818
isdiffeq, isdifferential,
19-
get_structure, get_reduced_states, default_u0, default_p
19+
get_structure, get_reduced_states, defaults
2020

2121
using ModelingToolkit.BipartiteGraphs
2222
using ModelingToolkit.SystemStructures

src/structural_transformation/codegen.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function gen_nlsolve(sys, eqs, vars)
131131
allvars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
132132
params = setdiff(allvars, vars)
133133

134-
u0map = default_u0(sys)
134+
u0map = defaults(sys)
135135
# splatting to tighten the type
136136
u0 = [map(var->get(u0map, var, 1e-3), vars)...]
137137
# specialize on the scalar case
@@ -338,22 +338,11 @@ function ODAEProblem{iip}(
338338
s = structure(sys)
339339
@unpack fullvars = s
340340
dvs = fullvars[diffvars_range(s)]
341-
defaults = merge(default_p(sys), default_u0(sys))
342-
u0map′ = ModelingToolkit.lower_mapnames(u0map, independent_variable(sys))
343-
u0 = ModelingToolkit.varmap_to_vars(u0map′, dvs; defaults=defaults)
344-
345341
ps = parameters(sys)
346-
if parammap isa DiffEqBase.NullParameters && isempty(default_p(sys))
347-
isempty(ps) || throw(ArgumentError("The model has non-empty parameters but no parameters are specified in the problem."))
348-
p = parammap
349-
else
350-
if parammap isa DiffEqBase.NullParameters
351-
pp = Pair[]
352-
else
353-
pp = ModelingToolkit.lower_mapnames(parammap)
354-
end
355-
p = ModelingToolkit.varmap_to_vars(pp, ps; defaults=defaults)
356-
end
342+
defs = defaults(sys)
343+
344+
u0 = ModelingToolkit.varmap_to_vars(u0map, dvs; defaults=defs)
345+
p = ModelingToolkit.varmap_to_vars(parammap, ps; defaults=defs)
357346

358347
ODEProblem{iip}(build_torn_function(sys; kw...), u0, tspan, p; kw...)
359348
end

src/systems/abstractsystem.jl

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ for prop in [
140140
:iv
141141
:states
142142
:ps
143-
:default_p
144-
:default_u0
143+
:defaults
145144
:observed
146145
:tgrad
147146
:jac
@@ -155,6 +154,10 @@ for prop in [
155154
:controls
156155
:loss
157156
:reduced_states
157+
:bcs
158+
:domain
159+
:depvars
160+
:indvars
158161
]
159162
fname1 = Symbol(:get_, prop)
160163
fname2 = Symbol(:has_, prop)
@@ -253,13 +256,13 @@ function Base.setproperty!(sys::AbstractSystem, prop::Symbol, val)
253256
idx = findfirst(s->getname(s) == prop, params);
254257
idx !== nothing;
255258
)
256-
get_default_p(sys)[params[idx]] = value(val)
259+
get_defaults(sys)[params[idx]] = value(val)
257260
elseif (
258261
sts = states(sys);
259262
idx = findfirst(s->getname(s) == prop, sts);
260263
idx !== nothing;
261264
)
262-
get_default_u0(sys)[sts[idx]] = value(val)
265+
get_defaults(sys)[sts[idx]] = value(val)
263266
else
264267
setfield!(sys, prop, val)
265268
end
@@ -280,14 +283,9 @@ end
280283
namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
281284
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
282285

283-
function namespace_default_u0(sys)
284-
d_u0 = default_u0(sys)
285-
Dict(states(sys, k) => namespace_expr(d_u0[k], nameof(sys), independent_variable(sys)) for k in keys(d_u0))
286-
end
287-
288-
function namespace_default_p(sys)
289-
d_p = default_p(sys)
290-
Dict(parameters(sys, k) => namespace_expr(d_p[k], nameof(sys), independent_variable(sys)) for k in keys(d_p))
286+
function namespace_defaults(sys)
287+
defs = defaults(sys)
288+
Dict((isparameter(k) ? parameters(sys, k) : states(sys, k)) => namespace_expr(defs[k], nameof(sys), independent_variable(sys)) for k in keys(defs))
291289
end
292290

293291
function namespace_equations(sys::AbstractSystem)
@@ -344,16 +342,12 @@ function observed(sys::AbstractSystem)
344342
init=Equation[])]
345343
end
346344

347-
function default_u0(sys::AbstractSystem)
345+
Base.@deprecate default_u0(x) defaults(x) false
346+
Base.@deprecate default_p(x) defaults(x) false
347+
function defaults(sys::AbstractSystem)
348348
systems = get_systems(sys)
349-
d_u0 = get_default_u0(sys)
350-
isempty(systems) ? d_u0 : mapreduce(namespace_default_u0, merge, systems; init=d_u0)
351-
end
352-
353-
function default_p(sys::AbstractSystem)
354-
systems = get_systems(sys)
355-
d_p = get_default_p(sys)
356-
isempty(systems) ? d_p : mapreduce(namespace_default_p, merge, systems; init=d_p)
349+
defs = get_defaults(sys)
350+
isempty(systems) ? defs : mapreduce(namespace_defaults, merge, systems; init=defs)
357351
end
358352

359353
states(sys::AbstractSystem, v) = renamespace(nameof(sys), v)
@@ -398,9 +392,13 @@ function (f::AbstractSysToExpr)(O)
398392
return build_expr(:call, Any[operation(O); f.(arguments(O))])
399393
end
400394

401-
function Base.show(io::IO, sys::AbstractSystem)
395+
function Base.show(io::IO, ::MIME"text/plain", sys::AbstractSystem)
402396
eqs = equations(sys)
403-
Base.printstyled(io, "Model $(nameof(sys)) with $(length(eqs)) equations\n"; bold=true)
397+
if eqs isa AbstractArray
398+
Base.printstyled(io, "Model $(nameof(sys)) with $(length(eqs)) equations\n"; bold=true)
399+
else
400+
Base.printstyled(io, "Model $(nameof(sys))\n"; bold=true)
401+
end
404402
# The reduced equations are usually very long. It's not that useful to print
405403
# them.
406404
#Base.print_matrix(io, eqs)
@@ -413,13 +411,13 @@ function Base.show(io::IO, sys::AbstractSystem)
413411
Base.printstyled(io, "States ($nvars):"; bold=true)
414412
nrows = min(nvars, limit ? rows : nvars)
415413
limited = nrows < length(vars)
416-
d_u0 = has_default_u0(sys) ? default_u0(sys) : nothing
414+
defs = has_defaults(sys) ? defaults(sys) : nothing
417415
for i in 1:nrows
418416
s = vars[i]
419417
print(io, "\n ", s)
420418

421-
if d_u0 !== nothing
422-
val = get(d_u0, s, nothing)
419+
if defs !== nothing
420+
val = get(defs, s, nothing)
423421
if val !== nothing
424422
print(io, " [defaults to $val]")
425423
end
@@ -432,13 +430,12 @@ function Base.show(io::IO, sys::AbstractSystem)
432430
Base.printstyled(io, "Parameters ($nvars):"; bold=true)
433431
nrows = min(nvars, limit ? rows : nvars)
434432
limited = nrows < length(vars)
435-
d_p = has_default_p(sys) ? default_p(sys) : nothing
436433
for i in 1:nrows
437434
s = vars[i]
438435
print(io, "\n ", s)
439436

440-
if d_p !== nothing
441-
val = get(d_p, s, nothing)
437+
if defs !== nothing
438+
val = get(defs, s, nothing)
442439
if val !== nothing
443440
print(io, " [defaults to $val]")
444441
end

src/systems/control/controlsystem.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,33 +68,30 @@ struct ControlSystem <: AbstractControlSystem
6868
"""
6969
systems::Vector{ControlSystem}
7070
"""
71-
default_u0: The default initial conditions to use when initial conditions
72-
are not supplied in `ODEProblem`.
71+
defaults: The default values to use when initial conditions and/or
72+
parameters are not supplied in `ODEProblem`.
7373
"""
74-
default_u0::Dict
75-
"""
76-
default_p: The default parameters to use when parameters are not supplied
77-
in `ODEProblem`.
78-
"""
79-
default_p::Dict
74+
defaults::Dict
8075
end
8176

8277
function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps;
8378
observed = [],
8479
systems = ODESystem[],
8580
default_u0=Dict(),
8681
default_p=Dict(),
82+
defaults=_merge(Dict(default_u0), Dict(default_p)),
8783
name=gensym(:ControlSystem))
84+
if !(isempty(default_u0) && isempty(default_p))
85+
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ControlSystem, force=true)
86+
end
8887
iv′ = value(iv)
8988
dvs′ = value.(dvs)
9089
controls′ = value.(controls)
9190
ps′ = value.(ps)
92-
default_u0 isa Dict || (default_u0 = Dict(default_u0))
93-
default_p isa Dict || (default_p = Dict(default_p))
94-
default_u0 = Dict(value(k) => value(default_u0[k]) for k in keys(default_u0))
95-
default_p = Dict(value(k) => value(default_p[k]) for k in keys(default_p))
91+
defaults = todict(defaults)
92+
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
9693
ControlSystem(value(loss), deqs, iv′, dvs′, controls′,
97-
ps′, observed, name, systems, default_u0, default_p)
94+
ps′, observed, name, systems, defaults)
9895
end
9996

10097
struct ControlToExpr

0 commit comments

Comments
 (0)