Skip to content

Commit 5bf0c15

Browse files
WIP: Generalize modelingtoolkitize to handle LabelledArrays
Fixes #1054 Current issue: ```julia using DifferentialEquations using LabelledArrays using ModelingToolkit #using DynamicHMC Random.seed!(42) # ODE model: simple SIR model with seasonally forced contact rate function SIR!(du,u,p,t) # states (S, I, R) = u[1:3] @show typeof(S) N = S + I + R # params β = p.β η = p.η φ = p.φ ω = 1.0/p.ω μ = p.μ σ = p.σ # FOI βeff = β * (1.0+η*cos(2.0*π*(t-φ)/365.0)) λ = βeff*I/N # change in states du[1] = (μ*N - λ*S - μ*S + ω*R) du[2] = (λ*S - σ*I - μ*I) du[3] = (σ*I - μ*R - ω*R) du[4] = (σ*I) # cumulative incidence end # Solver settings tmin = 0.0 tmax = 10.0*365.0 tspan = (tmin, tmax) solvsettings = (abstol = 1.0e-6, reltol = 1.0e-3, saveat = 7.0, solver = AutoTsit5(Rosenbrock23())) # Initiate ODE problem theta_fix = [1.0/(80*365)] theta_est = [0.28, 0.07, 1.0/365.0, 1.0 ,1.0/5.0] p = @larray [theta_est; theta_fix] (:β, :η, :ω, :φ, :σ, :μ) u0 = @larray [9998.0,1.0,1.0,1.0] (:S,:I,:R,:C) # Initiate ODE problem problem = ODEProblem(SIR!,u0,tspan,p) modelingtoolkitize(problem) problem_acc = ODEProblem(modelingtoolkitize(problem), u0, tspan, p, jac=true, sparse=true) ModelingToolkit.define_vars(p,@variables t) sol = solve(problem_acc, solvsettings.solver, abstol=solvsettings.abstol, reltol=solvsettings.reltol, isoutofdomain=(u,p,t)->any(x->x<0.0,u), saveat=solvsettings.saveat) ``` ```julia ArgumentError: The function + cannot be applied to S which is not a Number-like object.Define `islike(::Num, ::Type{Number}) = true` to enable this. assert_like(f::Function, T::Type, a::Num, b::Num) at methods.jl:26 +(a::Num, b::Num) at methods.jl:56 + at operators.jl:560 [inlined] SIR!(du::LArray{Num, 1, Vector{Num}, (:S, :I, :R, :C)}, u::LArray{Num, 1, Vector{Num}, (:S, :I, :R, :C)}, p::LArray{Num, 1, Vector{Num}, (:β, :η, :ω, :φ, :σ, :μ)}, t::Num) at test.jl:14 (::ODEFunction{true, typeof(SIR!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing})(::LArray{Num, 1, Vector{Num}, (:S, :I, :R, :C)}, ::Vararg{Any, N} where N) at scimlfunctions.jl:334 modelingtoolkitize(prob::ODEProblem{LArray{Float64, 1, Vector{Float64}, (:S, :I, :R, :C)}, Tuple{Float64, Float64}, true, LArray{Float64, 1, Vector{Float64}, (:β, :η, :ω, :φ, :σ, :μ)}, ODEFunction{true, typeof(SIR!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}) at modelingtoolkitize.jl:50 top-level scope at test.jl:53 eval at boot.jl:360 [inlined] ```
1 parent 48f87ac commit 5bf0c15

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
1616

1717
has_p = !(p isa Union{DiffEqBase.NullParameters,Nothing})
1818

19-
var(x, i) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
20-
_vars = [var(:x, i)(ModelingToolkit.value(t)) for i in eachindex(prob.u0)]
19+
_vars = define_vars(prob.u0,t)
20+
2121
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0,_vars)
2222
params = if has_p
23-
_params = [Num(toparam(Sym{Real}(nameof(Variable(, i))))) for i in eachindex(p)]
24-
p isa Number ? _params[1] : reshape(_params,size(p))
23+
_params = define_params(p)
24+
p isa Number ? _params[1] : ArrayInterface.restructure(p,_params)
2525
else
2626
[]
2727
end
@@ -46,7 +46,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
4646
end
4747

4848
if DiffEqBase.isinplace(prob)
49-
rhs = similar(vars, Num)
49+
rhs = ArrayInterface.restructure(prob.u0,similar(vars, Num))
5050
prob.f(rhs, vars, params, t)
5151
else
5252
rhs = prob.f(vars, params, t)
@@ -71,6 +71,26 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
7171
de
7272
end
7373

74+
_defvaridx(x, i, t) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x, i))))
75+
_defvar(x, t) = Num(Sym{FnType{Tuple{symtype(t)}, Real}}(nameof(Variable(x))))
76+
77+
function define_vars(u,t)
78+
_vars = [_defvaridx(:x, i, t)(ModelingToolkit.value(t)) for i in eachindex(u)]
79+
end
80+
81+
function define_vars(u::Union{SLArray,LArray},t)
82+
_vars = [_defvar(x, t)(ModelingToolkit.value(t)) for x in LabelledArrays.symnames(typeof(u))]
83+
end
84+
85+
function define_params(p)
86+
[Num(toparam(Sym{Real}(nameof(Variable(, i))))) for i in eachindex(p)]
87+
end
88+
89+
function define_params(p::Union{SLArray,LArray})
90+
[Num(toparam(Sym{Real}(nameof(Variable(x))))) for x in LabelledArrays.symnames(typeof(p))]
91+
end
92+
93+
7494
"""
7595
$(TYPEDSIGNATURES)
7696

0 commit comments

Comments
 (0)