Skip to content

Commit eb6357c

Browse files
committed
Support parameter dict in modelingtoolkitize
1 parent 2c1ed92 commit eb6357c

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
1717
params = if has_p
1818
_params = define_params(p)
1919
p isa Number ? _params[1] :
20-
(p isa Tuple || p isa NamedTuple ? _params :
20+
(p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
2121
ArrayInterfaceCore.restructure(p, _params))
2222
else
2323
[]
@@ -44,6 +44,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
4444

4545
if DiffEqBase.isinplace(prob)
4646
rhs = ArrayInterfaceCore.restructure(prob.u0, similar(vars, Num))
47+
fill!(rhs, 0)
4748
prob.f(rhs, vars, params, t)
4849
else
4950
rhs = prob.f(vars, params, t)
@@ -53,6 +54,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
5354

5455
sts = vec(collect(vars))
5556

57+
params = values(params)
5658
params = if params isa Number || (params isa Array && ndims(params) == 0)
5759
[params[1]]
5860
else
@@ -69,29 +71,33 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
6971
de
7072
end
7173

72-
_defvaridx(x, i, t) = variable(x, i, T = SymbolicUtils.FnType{Tuple, Real})
73-
_defvar(x, t) = variable(x, T = SymbolicUtils.FnType{Tuple, Real})
74+
_defvaridx(x, i) = variable(x, i, T = SymbolicUtils.FnType{Tuple, Real})
75+
_defvar(x) = variable(x, T = SymbolicUtils.FnType{Tuple, Real})
7476

7577
function define_vars(u, t)
76-
_vars = [_defvaridx(:x, i, t)(t) for i in eachindex(u)]
78+
[_defvaridx(:x, i)(t) for i in eachindex(u)]
7779
end
7880

7981
function define_vars(u::Union{SLArray, LArray}, t)
80-
_vars = [_defvar(x, t)(t) for x in LabelledArrays.symnames(typeof(u))]
82+
[_defvar(x)(t) for x in LabelledArrays.symnames(typeof(u))]
8183
end
8284

8385
function define_vars(u::Tuple, t)
84-
_vars = tuple((_defvaridx(:x, i, t)(ModelingToolkit.value(t)) for i in eachindex(u))...)
86+
tuple((_defvaridx(:x, i)(ModelingToolkit.value(t)) for i in eachindex(u))...)
8587
end
8688

8789
function define_vars(u::NamedTuple, t)
88-
_vars = NamedTuple(x => _defvar(x, t)(ModelingToolkit.value(t)) for x in keys(u))
90+
NamedTuple(x => _defvar(x)(ModelingToolkit.value(t)) for x in keys(u))
8991
end
9092

9193
function define_params(p)
9294
[toparam(variable(, i)) for i in eachindex(p)]
9395
end
9496

97+
function define_params(p::AbstractDict)
98+
Dict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
99+
end
100+
95101
function define_params(p::Union{SLArray, LArray})
96102
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
97103
end

0 commit comments

Comments
 (0)