Skip to content

Commit ba38a7a

Browse files
Merge pull request #1671 from SciML/myb/dict
Support parameter dict in modelingtoolkitize
2 parents 2c1ed92 + 325539e commit ba38a7a

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
1717
params = if has_p
1818
_params = define_params(p)
1919
p isa Number ? _params[1] :
20-
(p isa Tuple || p isa NamedTuple ? _params :
20+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
2121
ArrayInterfaceCore.restructure(p, _params))
2222
else
2323
[]
@@ -44,6 +44,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
4444

4545
if DiffEqBase.isinplace(prob)
4646
rhs = ArrayInterfaceCore.restructure(prob.u0, similar(vars, Num))
47+
fill!(rhs, 0)
4748
prob.f(rhs, vars, params, t)
4849
else
4950
rhs = prob.f(vars, params, t)
@@ -53,13 +54,23 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
5354

5455
sts = vec(collect(vars))
5556

57+
_params = params
58+
params = values(params)
5659
params = if params isa Number || (params isa Array && ndims(params) == 0)
5760
[params[1]]
5861
else
5962
vec(collect(params))
6063
end
6164
default_u0 = Dict(sts .=> vec(collect(prob.u0)))
62-
default_p = has_p ? Dict(params .=> vec(collect(prob.p))) : Dict()
65+
default_p = if has_p
66+
if prob.p isa AbstractDict
67+
Dict(v => prob.p[k] for (k, v) in pairs(_params))
68+
else
69+
Dict(params .=> vec(collect(prob.p)))
70+
end
71+
else
72+
Dict()
73+
end
6374

6475
de = ODESystem(eqs, t, sts, params,
6576
defaults = merge(default_u0, default_p);
@@ -69,29 +80,33 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
6980
de
7081
end
7182

72-
_defvaridx(x, i, t) = variable(x, i, T = SymbolicUtils.FnType{Tuple, Real})
73-
_defvar(x, t) = variable(x, T = SymbolicUtils.FnType{Tuple, Real})
83+
_defvaridx(x, i) = variable(x, i, T = SymbolicUtils.FnType{Tuple, Real})
84+
_defvar(x) = variable(x, T = SymbolicUtils.FnType{Tuple, Real})
7485

7586
function define_vars(u, t)
76-
_vars = [_defvaridx(:x, i, t)(t) for i in eachindex(u)]
87+
[_defvaridx(:x, i)(t) for i in eachindex(u)]
7788
end
7889

7990
function define_vars(u::Union{SLArray, LArray}, t)
80-
_vars = [_defvar(x, t)(t) for x in LabelledArrays.symnames(typeof(u))]
91+
[_defvar(x)(t) for x in LabelledArrays.symnames(typeof(u))]
8192
end
8293

8394
function define_vars(u::Tuple, t)
84-
_vars = tuple((_defvaridx(:x, i, t)(ModelingToolkit.value(t)) for i in eachindex(u))...)
95+
tuple((_defvaridx(:x, i)(ModelingToolkit.value(t)) for i in eachindex(u))...)
8596
end
8697

8798
function define_vars(u::NamedTuple, t)
88-
_vars = NamedTuple(x => _defvar(x, t)(ModelingToolkit.value(t)) for x in keys(u))
99+
NamedTuple(x => _defvar(x)(ModelingToolkit.value(t)) for x in keys(u))
89100
end
90101

91102
function define_params(p)
92103
[toparam(variable(, i)) for i in eachindex(p)]
93104
end
94105

106+
function define_params(p::AbstractDict)
107+
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
108+
end
109+
95110
function define_params(p::Union{SLArray, LArray})
96111
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
97112
end

test/modelingtoolkitize.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, ModelingToolkit, Test
1+
using OrdinaryDiffEq, ModelingToolkit, DataStructures, Test
22
using Optimization, RecursiveArrayTools, OptimizationOptimJL
33

44
N = 32
@@ -277,3 +277,15 @@ params = (1, 1)
277277
prob = ODEProblem(ode_prob, [1 1], (0, 1), params)
278278
sys = modelingtoolkitize(prob)
279279
@test nameof.(parameters(sys)) == [:α₁, :α₂]
280+
281+
function ode_prob_dict(du, u, p, t)
282+
du[1] = u[1] + p[:a]
283+
du[2] = u[2] + p[:b]
284+
nothing
285+
end
286+
params = OrderedDict(:a => 10, :b => 20)
287+
u0 = [1, 2.0]
288+
prob = ODEProblem(ode_prob_dict, u0, (0.0, 1.0), params)
289+
sys = modelingtoolkitize(prob)
290+
@test [ModelingToolkit.defaults(sys)[s] for s in states(sys)] == u0
291+
@test [ModelingToolkit.defaults(sys)[s] for s in parameters(sys)] == [10, 20]

0 commit comments

Comments
 (0)