Skip to content

Commit 2309f9f

Browse files
authored
Merge pull request #2283 from SciML/bgc/split_params_bug
Support heterogeneous parameters for linearize and remake
2 parents 8886a91 + ae05041 commit 2309f9f

File tree

15 files changed

+236
-92
lines changed

15 files changed

+236
-92
lines changed

docs/src/basics/MTKModel_Connector.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ end
130130
`@connector`s accepts begin blocks of `@components`, `@equations`, `@extend`, `@parameters`, `@structural_parameters`, `@variables`. These keywords mean the same as described above for `@mtkmodel`.
131131

132132
!!! note
133+
133134
For more examples of usage, checkout [ModelingToolkitStandardLibrary.jl](https://github.com/SciML/ModelingToolkitStandardLibrary.jl/)
134135

135136
### What's a `structure` dictionary?

src/systems/abstractsystem.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,8 @@ for prop in [:eqs
229229
:metadata
230230
:gui_metadata
231231
:discrete_subsystems
232-
:unknown_states]
232+
:unknown_states
233+
:split_idxs]
233234
fname1 = Symbol(:get_, prop)
234235
fname2 = Symbol(:has_, prop)
235236
@eval begin
@@ -1273,14 +1274,34 @@ See also [`linearize`](@ref) which provides a higher-level interface.
12731274
function linearization_function(sys::AbstractSystem, inputs,
12741275
outputs; simplify = false,
12751276
initialize = true,
1277+
op = Dict(),
1278+
p = DiffEqBase.NullParameters(),
1279+
zero_dummy_der = false,
12761280
kwargs...)
1277-
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
1281+
ssys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs;
1282+
simplify,
12781283
kwargs...)
1284+
if zero_dummy_der
1285+
dummyder = setdiff(states(ssys), states(sys))
1286+
defs = Dict(x => 0.0 for x in dummyder)
1287+
@set! ssys.defaults = merge(defs, defaults(ssys))
1288+
op = merge(defs, op)
1289+
end
1290+
sys = ssys
1291+
x0 = merge(defaults(sys), op)
1292+
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1293+
p, split_idxs = split_parameters_by_type(p)
1294+
ps = parameters(sys)
1295+
if p isa Tuple
1296+
ps = Base.Fix1(getindex, ps).(split_idxs)
1297+
ps = (ps...,) #if p is Tuple, ps should be Tuple
1298+
end
1299+
12791300
lin_fun = let diff_idxs = diff_idxs,
12801301
alge_idxs = alge_idxs,
12811302
input_idxs = input_idxs,
12821303
sts = states(sys),
1283-
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
1304+
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, states(sys), ps; p = p),
12841305
h = build_explicit_observed_function(sys, outputs),
12851306
chunk = ForwardDiff.Chunk(input_idxs)
12861307

@@ -1599,11 +1620,12 @@ function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
15991620
allow_input_derivatives = false,
16001621
zero_dummy_der = false,
16011622
kwargs...)
1602-
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
1603-
if zero_dummy_der
1604-
dummyder = setdiff(states(ssys), states(sys))
1605-
op = merge(op, Dict(x => 0.0 for x in dummyder))
1606-
end
1623+
lin_fun, ssys = linearization_function(sys,
1624+
inputs,
1625+
outputs;
1626+
zero_dummy_der,
1627+
op,
1628+
kwargs...)
16071629
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
16081630
end
16091631

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
354354
checkbounds = false,
355355
sparsity = false,
356356
analytic = nothing,
357+
split_idxs = nothing,
357358
kwargs...) where {iip, specialize}
358359
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
359360
expression_module = eval_module, checkbounds = checkbounds,
@@ -508,6 +509,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
508509
nothing
509510
end
510511

512+
@set! sys.split_idxs = split_idxs
511513
ODEFunction{iip, specialize}(f;
512514
sys = sys,
513515
jac = _jac === nothing ? nothing : _jac,
@@ -765,15 +767,17 @@ Take dictionaries with initial conditions and parameters and convert them to num
765767
"""
766768
function get_u0_p(sys,
767769
u0map,
768-
parammap;
770+
parammap = nothing;
769771
use_union = true,
770772
tofloat = true,
771773
symbolic_u0 = false)
772774
dvs = states(sys)
773775
ps = parameters(sys)
774776

775777
defs = defaults(sys)
776-
defs = mergedefaults(defs, parammap, ps)
778+
if parammap !== nothing
779+
defs = mergedefaults(defs, parammap, ps)
780+
end
777781
defs = mergedefaults(defs, u0map, dvs)
778782

779783
if symbolic_u0
@@ -835,7 +839,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
835839
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
836840
checkbounds = checkbounds, p = p,
837841
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
838-
sparse = sparse, eval_expression = eval_expression,
842+
sparse = sparse, eval_expression = eval_expression, split_idxs,
839843
kwargs...)
840844
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
841845
end

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,19 @@ struct ODESystem <: AbstractODESystem
139139
used for ODAEProblem.
140140
"""
141141
unknown_states::Union{Nothing, Vector{Any}}
142+
"""
143+
split_idxs: a vector of vectors of indices for the split parameters.
144+
"""
145+
split_idxs::Union{Nothing, Vector{Vector{Int}}}
142146

143147
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
144148
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
145149
torn_matching, connector_type, preface, cevents,
146150
devents, metadata = nothing, gui_metadata = nothing,
147151
tearing_state = nothing,
148152
substitutions = nothing, complete = false,
149-
discrete_subsystems = nothing, unknown_states = nothing;
150-
checks::Union{Bool, Int} = true)
153+
discrete_subsystems = nothing, unknown_states = nothing,
154+
split_idxs = nothing; checks::Union{Bool, Int} = true)
151155
if checks == true || (checks & CheckComponents) > 0
152156
check_variables(dvs, iv)
153157
check_parameters(ps, iv)
@@ -161,7 +165,7 @@ struct ODESystem <: AbstractODESystem
161165
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
162166
connector_type, preface, cevents, devents, metadata, gui_metadata,
163167
tearing_state, substitutions, complete, discrete_subsystems,
164-
unknown_states)
168+
unknown_states, split_idxs)
165169
end
166170
end
167171

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
293293
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
294294
parammap = DiffEqBase.NullParameters();
295295
checkbounds = false,
296-
use_union = false,
296+
use_union = true,
297297
kwargs...)
298298
dvs = states(sys)
299299
ps = parameters(sys)

src/utils.jl

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -657,46 +657,40 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
657657
end
658658
T = eltype(vs)
659659
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
660-
vs
660+
return vs
661661
else
662662
sym_vs = filter(x -> SymbolicUtils.issym(x) || SymbolicUtils.istree(x), vs)
663663
isempty(sym_vs) || throw_missingvars_in_sys(sym_vs)
664-
C = typeof(first(vs))
665-
I = Int8
666-
has_int = false
667-
has_array = false
668-
has_bool = false
669-
array_T = nothing
664+
665+
C = nothing
670666
for v in vs
671-
if v isa AbstractArray
672-
has_array = true
673-
array_T = typeof(v)
667+
E = typeof(v)
668+
if E <: Number
669+
if tofloat
670+
E = float(E)
671+
end
674672
end
675-
E = eltype(v)
676-
C = promote_type(C, E)
677-
if E <: Integer
678-
has_int = true
679-
I = promote_type(I, E)
673+
if C === nothing
674+
C = E
680675
end
681-
if E <: Bool
682-
has_bool = true
676+
if use_union
677+
C = Union{C, E}
678+
else
679+
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
680+
C = E
683681
end
684682
end
685-
if tofloat && !has_array
686-
C = float(C)
687-
elseif has_array || (use_union && has_int && C !== I)
688-
if has_array
689-
C = Union{C, array_T}
690-
end
691-
if has_int
692-
C = Union{C, I}
693-
end
694-
if has_bool
695-
C = Union{C, Bool}
683+
684+
y = similar(vs, C)
685+
for i in eachindex(vs)
686+
if (vs[i] isa Number) & tofloat
687+
y[i] = float(vs[i]) #needed because copyto! can't convert Int to Float automatically
688+
else
689+
y[i] = vs[i]
696690
end
697-
return copyto!(similar(vs, C), vs)
698691
end
699-
convert.(C, vs)
692+
693+
return y
700694
end
701695
end
702696

src/variables.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,23 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
145145
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
146146
end
147147

148-
# assemble defaults
149-
defs = defaults(prob.f.sys)
150-
defs = mergedefaults(defs, prob.p, parameters(prob.f.sys))
151-
defs = mergedefaults(defs, p, parameters(prob.f.sys))
152-
defs = mergedefaults(defs, prob.u0, states(prob.f.sys))
153-
defs = mergedefaults(defs, u0, states(prob.f.sys))
154-
155-
u0 = varmap_to_vars(u0, states(prob.f.sys); defaults = defs, tofloat = true)
156-
p = varmap_to_vars(p, parameters(prob.f.sys); defaults = defs)
148+
sys = prob.f.sys
149+
defs = defaults(sys)
150+
ps = parameters(sys)
151+
if has_split_idxs(sys) && (split_idxs = get_split_idxs(sys)) !== nothing
152+
for (i, idxs) in enumerate(split_idxs)
153+
defs = mergedefaults(defs, prob.p[i], ps[idxs])
154+
end
155+
else
156+
# assemble defaults
157+
defs = defaults(sys)
158+
defs = mergedefaults(defs, prob.p, ps)
159+
end
160+
defs = mergedefaults(defs, p, ps)
161+
sts = states(sys)
162+
defs = mergedefaults(defs, prob.u0, sts)
163+
defs = mergedefaults(defs, u0, sts)
164+
u0, p, defs = get_u0_p(sys, defs)
157165

158166
return p, u0
159167
end

test/input_output_handling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ using ModelingToolkitStandardLibrary.Mechanical.Rotational
124124
t = ModelingToolkitStandardLibrary.Mechanical.Rotational.t
125125
@named inertia1 = Inertia(; J = 1)
126126
@named inertia2 = Inertia(; J = 1)
127-
@named spring = Spring(; c = 10)
128-
@named damper = Damper(; d = 3)
127+
@named spring = Rotational.Spring(; c = 10)
128+
@named damper = Rotational.Damper(; d = 3)
129129
@named torque = Torque(; use_support = false)
130130
@variables y(t) = 0
131131
eqs = [connect(torque.flange, inertia1.flange_a)

test/latexify/10.tex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
\begin{align}
2-
\frac{\mathrm{d} x\left( t \right)}{\mathrm{d}t} =& \frac{\sigma \left( - x\left( t \right) + y\left( t \right) \right) \frac{\mathrm{d}}{\mathrm{d}t} \left( - y\left( t \right) + x\left( t \right) \right)}{\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t}} \\
3-
0 =& - y\left( t \right) + \frac{1}{10} \sigma \left( \rho - z\left( t \right) \right) x\left( t \right) \\
4-
\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t} =& \left( y\left( t \right) \right)^{\frac{2}{3}} x\left( t \right) - \beta z\left( t \right)
2+
\frac{\mathrm{d} x\left( t \right)}{\mathrm{d}t} =& \frac{\left( - x\left( t \right) + y\left( t \right) \right) \frac{\mathrm{d}}{\mathrm{d}t} \left( x\left( t \right) - y\left( t \right) \right) \sigma}{\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t}} \\
3+
0 =& - y\left( t \right) + \frac{1}{10} x\left( t \right) \left( - z\left( t \right) + \rho \right) \sigma \\
4+
\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t} =& \left( y\left( t \right) \right)^{\frac{2}{3}} x\left( t \right) - z\left( t \right) \beta
55
\end{align}

test/latexify/20.tex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
\begin{align}
2-
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
3-
0 =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
2+
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& p_3 \left( - u(t)_1 + u(t)_2 \right) \\
3+
0 =& - u(t)_2 + \frac{1}{10} \left( p_1 - u(t)_1 \right) p_2 p_3 u(t)_1 \\
44
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_3 =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
55
\end{align}

0 commit comments

Comments
 (0)