Skip to content

Commit a652789

Browse files
Merge pull request #1162 from SciML/modelingtoolkitize
fix modelingtoolkitize for tuple and namedtuple parameters
2 parents c7b00d6 + 8935b57 commit a652789

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,15 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
88
return prob.f.sys
99
@parameters t
1010

11-
if prob.p isa Tuple || prob.p isa NamedTuple
12-
p = [x for x in prob.p]
13-
else
14-
p = prob.p
15-
end
16-
11+
p = prob.p
1712
has_p = !(p isa Union{DiffEqBase.NullParameters,Nothing})
1813

1914
_vars = define_vars(prob.u0,t)
2015

2116
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0,_vars)
2217
params = if has_p
2318
_params = define_params(p)
24-
p isa Number ? _params[1] : ArrayInterface.restructure(p,_params)
19+
p isa Number ? _params[1] : (p isa Tuple || p isa NamedTuple ? _params : ArrayInterface.restructure(p,_params))
2520
else
2621
[]
2722
end
@@ -55,7 +50,8 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
5550
eqs = vcat([lhs[i] ~ rhs[i] for i in eachindex(prob.u0)]...)
5651

5752
sts = vec(collect(vars))
58-
params = if ndims(params) == 0
53+
54+
params = if params isa Array && ndims(params) == 0
5955
[params[1]]
6056
else
6157
vec(collect(params))
@@ -83,6 +79,14 @@ function define_vars(u::Union{SLArray,LArray},t)
8379
_vars = [_defvar(x, t)(ModelingToolkit.value(t)) for x in LabelledArrays.symnames(typeof(u))]
8480
end
8581

82+
function define_vars(u::Tuple,t)
83+
_vars = tuple((_defvaridx(:x, i, t)(ModelingToolkit.value(t)) for i in eachindex(u))...)
84+
end
85+
86+
function define_vars(u::NamedTuple,t)
87+
_vars = NamedTuple(x=>_defvar(x, t)(ModelingToolkit.value(t)) for x in keys(u))
88+
end
89+
8690
function define_params(p)
8791
[Num(toparam(Sym{Real}(nameof(Variable(, i))))) for i in eachindex(p)]
8892
end
@@ -91,6 +95,14 @@ function define_params(p::Union{SLArray,LArray})
9195
[Num(toparam(Sym{Real}(nameof(Variable(x))))) for x in LabelledArrays.symnames(typeof(p))]
9296
end
9397

98+
function define_params(p::Tuple)
99+
tuple((Num(toparam(Sym{Real}(nameof(Variable(, i))))) for i in eachindex(p))...)
100+
end
101+
102+
function define_params(p::NamedTuple)
103+
@show NamedTuple(x=>Num(toparam(Sym{Real}(nameof(Variable(x))))) for x in keys(p))
104+
end
105+
94106

95107
"""
96108
$(TYPEDSIGNATURES)
@@ -101,15 +113,19 @@ function modelingtoolkitize(prob::DiffEqBase.SDEProblem; kwargs...)
101113
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
102114
return (prob.f.sys, prob.f.sys.states, prob.f.sys.ps)
103115
@parameters t
104-
if prob.p isa Tuple || prob.p isa NamedTuple
105-
p = [x for x in prob.p]
116+
p = prob.p
117+
has_p = !(p isa Union{DiffEqBase.NullParameters,Nothing})
118+
119+
_vars = define_vars(prob.u0,t)
120+
121+
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0,_vars)
122+
params = if has_p
123+
_params = define_params(p)
124+
p isa Number ? _params[1] : (p isa Tuple || p isa NamedTuple ? _params : ArrayInterface.restructure(p,_params))
106125
else
107-
p = prob.p
126+
[]
108127
end
109-
var(x, i) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
110-
vars = ArrayInterface.restructure(prob.u0,[var(:x, i)(ModelingToolkit.value(t)) for i in eachindex(prob.u0)])
111-
params = p isa DiffEqBase.NullParameters ? [] :
112-
reshape([Num(Sym{Real}(nameof(Variable(, i)))) for i in eachindex(p)],size(p))
128+
113129

114130
D = Differential(t)
115131

test/modelingtoolkitize.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,25 @@ sys = modelingtoolkitize(problem)
256256
@parameters t
257257
@test all(isequal.(parameters(sys),getproperty.(@variables(β, η, ω, φ, σ, μ),:val)))
258258
@test all(isequal.(Symbol.(states(sys)),Symbol.(@variables(S(t),I(t),R(t),C(t)))))
259+
260+
# https://github.com/SciML/ModelingToolkit.jl/issues/1158
261+
262+
function ode_prob(du, u, p::NamedTuple, t)
263+
du[1] = u[1]+p.α*u[2]
264+
du[2] = u[2]+p.β*u[1]
265+
end
266+
params == 1, β = 1)
267+
prob = ODEProblem(ode_prob, [1 1], (0, 1), params)
268+
sys = modelingtoolkitize(prob)
269+
@test nameof.(parameters(sys)) == [,]
270+
271+
function ode_prob(du, u, p::Tuple, t)
272+
α, β = p
273+
du[1] = u[1]+α*u[2]
274+
du[2] = u[2]+β*u[1]
275+
end
276+
277+
params = (1, 1)
278+
prob = ODEProblem(ode_prob, [1 1], (0, 1), params)
279+
sys = modelingtoolkitize(prob)
280+
@test nameof.(parameters(sys)) == [:α₁,:α₂]

0 commit comments

Comments
 (0)