Skip to content

Commit 5e666b5

Browse files
Merge pull request #677 from SciML/mtkitize_lowering
fix variable name lowering for non-map parameters and SIR regression
2 parents 23fc8fb + 8280935 commit 5e666b5

File tree

4 files changed

+122
-9
lines changed

4 files changed

+122
-9
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,16 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
267267
kwargs...) where iip
268268
dvs = states(sys)
269269
ps = parameters(sys)
270-
u0map′ = [lower_varname(value(k), sys.iv) => value(v) for (k, v) in u0map]
270+
u0map′ = lower_mapnames(u0map,sys.iv)
271271
u0 = varmap_to_vars(u0map′,dvs)
272+
272273
if !(parammap isa DiffEqBase.NullParameters)
273-
parammap′ = [value(k) => value(v) for (k, v) in parammap]
274+
parammap′ = lower_mapnames(parammap)
274275
p = varmap_to_vars(parammap′,ps)
275276
else
276277
p = ps
277278
end
279+
278280
f = ODEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
279281
linenumbers=linenumbers,parallel=parallel,simplify=simplify,
280282
sparse=sparse,eval_expression=eval_expression,kwargs...)
@@ -311,14 +313,16 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem,u0map,tspan,
311313

312314
dvs = states(sys)
313315
ps = parameters(sys)
314-
u0map′ = [lower_varname(value(k), sys.iv) => value(v) for (k, v) in u0map]
315-
parammap′ = [value(k) => value(v) for (k, v) in parammap]
316+
u0map′ = lower_mapnames(u0map,sys.iv)
316317
u0 = varmap_to_vars(u0map′,dvs)
318+
317319
if !(parammap isa DiffEqBase.NullParameters)
320+
parammap′ = lower_mapnames(parammap)
318321
p = varmap_to_vars(parammap′,ps)
319322
else
320323
p = ps
321324
end
325+
322326
f = ODEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,checkbounds=checkbounds,
323327
linenumbers=linenumbers,parallel=parallel,
324328
simplify=simplify,

src/systems/diffeqs/sdesystem.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,17 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
310310

311311
dvs = states(sys)
312312
ps = parameters(sys)
313-
u0 = varmap_to_vars(u0map,dvs)
314-
p = varmap_to_vars(parammap,ps)
313+
314+
u0map′ = lower_mapnames(u0map,sys.iv)
315+
u0 = varmap_to_vars(u0map′,dvs)
316+
317+
if !(parammap isa DiffEqBase.NullParameters)
318+
parammap′ = lower_mapnames(parammap)
319+
p = varmap_to_vars(parammap′,ps)
320+
else
321+
p = ps
322+
end
323+
315324
f = SDEFunction{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,Wfact=Wfact,
316325
checkbounds=checkbounds,
317326
linenumbers=linenumbers,parallel=parallel,
@@ -358,8 +367,17 @@ function SDEProblemExpr{iip}(sys::SDESystem,u0map,tspan,
358367
kwargs...) where iip
359368
dvs = states(sys)
360369
ps = parameters(sys)
361-
u0 = varmap_to_vars(u0map,dvs)
362-
p = varmap_to_vars(parammap,ps)
370+
371+
u0map′ = lower_mapnames(u0map,sys.iv)
372+
u0 = varmap_to_vars(u0map′,dvs)
373+
374+
if !(parammap isa DiffEqBase.NullParameters)
375+
parammap′ = lower_mapnames(parammap)
376+
p = varmap_to_vars(parammap′,ps)
377+
else
378+
p = ps
379+
end
380+
363381
f = SDEFunctionExpr{iip}(sys,dvs,ps,u0;tgrad=tgrad,jac=jac,
364382
Wfact=Wfact,checkbounds=checkbounds,
365383
linenumbers=linenumbers,parallel=parallel,

src/utils.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,15 @@ function lower_varname(t::Term, iv)
271271
end
272272
lower_varname(t::Sym, iv) = t
273273

274+
function lower_mapnames(umap::AbstractArray{<:Pair}) where T
275+
[value(k) => value(v) for (k, v) in umap]
276+
end
277+
function lower_mapnames(umap::AbstractArray{<:Pair},name) where T
278+
[lower_varname(value(k), name) => value(v) for (k, v) in umap]
279+
end
280+
lower_mapnames(umap::AbstractArray{<:Number}) = umap # Ambiguity
281+
lower_mapnames(umap::AbstractArray{<:Number},name) = umap
282+
274283
function flatten_differential(O::Term)
275284
@assert is_derivative(O) "invalid differential: $O"
276285
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)

test/modelingtoolkitize.jl

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using OrdinaryDiffEq, ModelingToolkit, Test
22
using GalacticOptim, Optim
33

4-
const N = 32
4+
N = 32
55
const xyd_brusselator = range(0,stop=1,length=N)
66
brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.
77
limit(a, N) = ModelingToolkit.ifelse(a == N+1, 1, ModelingToolkit.ifelse(a == 0, N, a))
@@ -59,3 +59,85 @@ sol = solve(prob,BFGS())
5959

6060
sol = solve(prob,Newton())
6161
@test sol.minimum < 1e-8
62+
63+
## SIR System Regression Test
64+
65+
β = 0.01# infection rate
66+
λ_R = 0.05 # inverse of transition time from infected to recovered
67+
λ_D = 0.83 # inverse of transition time from infected to dead
68+
i₀ = 0.075 # fraction of initial infected people in every age class
69+
𝒫 = vcat([β, λ_R, λ_D]...)
70+
71+
# regional contact matrix and regional population
72+
73+
## regional contact matrix
74+
regional_all_contact_matrix = [3.45536 0.485314 0.506389 0.123002 ; 0.597721 2.11738 0.911374 0.323385 ; 0.906231 1.35041 1.60756 0.67411 ; 0.237902 0.432631 0.726488 0.979258] # 4x4 contact matrix
75+
76+
## regional population stratified by age
77+
N = [723208 , 874150, 1330993, 1411928] # array of 4 elements, each of which representing the absolute amount of population in the corresponding age class.
78+
79+
80+
# Initial conditions
81+
I₀ = repeat([i₀],4)
82+
S₀ = N.-I₀
83+
R₀ = [0.0 for n in 1:length(N)]
84+
D₀ = [0.0 for n in 1:length(N)]
85+
D_tot₀ = [0.0 for n in 1:length(N)]
86+
= vcat([S₀, I₀, R₀, D₀, D_tot₀]...)
87+
88+
# Time
89+
final_time = 20
90+
𝒯 = (1.0,final_time);
91+
92+
93+
94+
95+
function SIRD_ac!(du,u,p,t)
96+
# Parameters to be calibrated
97+
β, λ_R, λ_D = p
98+
99+
# initialize this parameter (death probability stratified by age, taken from literature)
100+
101+
δ₁, δ₂, δ₃, δ₄ = [0.003/100, 0.004/100, (0.015+0.030+0.064+0.213+0.718)/(5*100), (2.384+8.466+12.497+1.117)/(4*100)]
102+
δ = vcat(repeat([δ₁],1),repeat([δ₂],1),repeat([δ₃],1),repeat([δ₄],4-1-1-1))
103+
104+
105+
C = regional_all_contact_matrix
106+
107+
108+
# State variables
109+
S = @view u[4*0+1:4*1]
110+
I = @view u[4*1+1:4*2]
111+
R = @view u[4*2+1:4*3]
112+
D = @view u[4*3+1:4*4]
113+
D_tot = @view u[4*4+1:4*5]
114+
115+
# Differentials
116+
dS = @view du[4*0+1:4*1]
117+
dI = @view du[4*1+1:4*2]
118+
dR = @view du[4*2+1:4*3]
119+
dD = @view du[4*3+1:4*4]
120+
dD_tot = @view du[4*4+1:4*5]
121+
122+
# Force of infection
123+
Λ = β*[sum([C[i,j]*I[j]/N[j] for j in 1:size(C)[1]]) for i in 1:size(C)[2]]
124+
125+
# System of equations
126+
@. dS = -Λ*S
127+
@. dI = Λ*S - ((1-δ)*λ_R + δ*λ_D)*I
128+
@. dR = λ_R*(1-δ)*I
129+
@. dD = λ_D*δ*I
130+
@. dD_tot = dD[1]+dD[2]+dD[3]+dD[4]
131+
132+
133+
end;
134+
135+
136+
# create problem and check it works
137+
problem = ODEProblem(SIRD_ac!, ℬ, 𝒯, 𝒫)
138+
@time solution = solve(problem, Tsit5(), saveat = 1:final_time);
139+
140+
problem = ODEProblem(SIRD_ac!, ℬ, 𝒯, 𝒫)
141+
sys = modelingtoolkitize(problem)
142+
fast_problem = ODEProblem(sys,ℬ, 𝒯, 𝒫 )
143+
@time solution = solve(fast_problem, Tsit5(), saveat = 1:final_time)

0 commit comments

Comments
 (0)