Skip to content

Commit fc2f626

Browse files
Merge pull request #1055 from SciML/mtk
Generalize modelingtoolkitize to handle LabelledArrays
2 parents 48f87ac + c61b3de commit fc2f626

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
1616

1717
has_p = !(p isa Union{DiffEqBase.NullParameters,Nothing})
1818

19-
var(x, i) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
20-
_vars = [var(:x, i)(ModelingToolkit.value(t)) for i in eachindex(prob.u0)]
19+
_vars = define_vars(prob.u0,t)
20+
2121
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0,_vars)
2222
params = if has_p
23-
_params = [Num(toparam(Sym{Real}(nameof(Variable(, i))))) for i in eachindex(p)]
24-
p isa Number ? _params[1] : reshape(_params,size(p))
23+
_params = define_params(p)
24+
p isa Number ? _params[1] : ArrayInterface.restructure(p,_params)
2525
else
2626
[]
2727
end
@@ -46,7 +46,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
4646
end
4747

4848
if DiffEqBase.isinplace(prob)
49-
rhs = similar(vars, Num)
49+
rhs = ArrayInterface.restructure(prob.u0,similar(vars, Num))
5050
prob.f(rhs, vars, params, t)
5151
else
5252
rhs = prob.f(vars, params, t)
@@ -71,6 +71,26 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
7171
de
7272
end
7373

74+
_defvaridx(x, i, t) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
75+
_defvar(x, t) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x))))
76+
77+
function define_vars(u,t)
78+
_vars = [_defvaridx(:x, i, t)(ModelingToolkit.value(t)) for i in eachindex(u)]
79+
end
80+
81+
function define_vars(u::Union{SLArray,LArray},t)
82+
_vars = [_defvar(x, t)(ModelingToolkit.value(t)) for x in LabelledArrays.symnames(typeof(u))]
83+
end
84+
85+
function define_params(p)
86+
[Num(toparam(Sym{Real}(nameof(Variable(, i))))) for i in eachindex(p)]
87+
end
88+
89+
function define_params(p::Union{SLArray,LArray})
90+
[Num(toparam(Sym{Real}(nameof(Variable(x))))) for x in LabelledArrays.symnames(typeof(p))]
91+
end
92+
93+
7494
"""
7595
$(TYPEDSIGNATURES)
7696

test/modelingtoolkitize.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,54 @@ x0 = 1.0
205205
tspan = (0.0,1.0)
206206
prob = ODEProblem(k,x0,tspan)
207207
sys = modelingtoolkitize(prob)
208+
209+
210+
## https://github.com/SciML/ModelingToolkit.jl/issues/1054
211+
using LabelledArrays
212+
using ModelingToolkit
213+
214+
# ODE model: simple SIR model with seasonally forced contact rate
215+
function SIR!(du,u,p,t)
216+
217+
# states
218+
(S, I, R) = u[1:3]
219+
N = S + I + R
220+
221+
# params
222+
β = p.β
223+
η = p.η
224+
φ = p.φ
225+
ω = 1.0/p.ω
226+
μ = p.μ
227+
σ = p.σ
228+
229+
# FOI
230+
βeff = β * (1.0+η*cos(2.0*π*(t-φ)/365.0))
231+
λ = βeff*I/N
232+
233+
# change in states
234+
du[1] =*N - λ*S - μ*S + ω*R)
235+
du[2] =*S - σ*I - μ*I)
236+
du[3] =*I - μ*R - ω*R)
237+
du[4] =*I) # cumulative incidence
238+
239+
end
240+
241+
# Solver settings
242+
tmin = 0.0
243+
tmax = 10.0*365.0
244+
tspan = (tmin, tmax)
245+
246+
# Initiate ODE problem
247+
theta_fix = [1.0/(80*365)]
248+
theta_est = [0.28, 0.07, 1.0/365.0, 1.0 ,1.0/5.0]
249+
p = @LArray [theta_est; theta_fix] (, , , , , )
250+
u0 = @LArray [9998.0,1.0,1.0,1.0] (:S,:I,:R,:C)
251+
252+
# Initiate ODE problem
253+
problem = ODEProblem(SIR!,u0,tspan,p)
254+
sys = modelingtoolkitize(problem)
255+
256+
@parameters t
257+
@test all(isequal.(parameters(sys),getproperty.(@variables(β, η, ω, φ, σ, μ),:val)))
258+
@test all(isequal.(Symbol.(states(sys)),Symbol.(@variables(S(t),I(t),R(t),C(t)))))

0 commit comments

Comments
 (0)