Skip to content

Commit d10d5e1

Browse files
committed
Support constants in jump systems.
1 parent 7d11edb commit d10d5e1

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,23 @@ function JumpSystem(eqs, iv, states, ps;
157157
end
158158

159159
function generate_rate_function(js::JumpSystem, rate)
160+
consts = collect_constants(rate)
161+
if !isempty(consts) # The SymbolicUtils._build_function method of this case doesn't support preprocessing
162+
csubs = Dict(c => getdefault(c) for c in consts)
163+
rate = substitute(rate, csubs)
164+
end
160165
rf = build_function(rate, states(js), parameters(js),
161166
get_iv(js),
162167
conv = states_to_sym(states(js)),
163168
expression = Val{true})
164169
end
165170

166171
function generate_affect_function(js::JumpSystem, affect, outputidxs)
172+
consts = collect_constants(affect)
173+
if !isempty(consts) # The SymbolicUtils._build_function method of this case doesn't support preprocessing
174+
csubs = Dict(c => getdefault(c) for c in consts)
175+
affect = substitute(affect, csubs)
176+
end
167177
compile_affect(affect, js, states(js), parameters(js); outputidxs = outputidxs,
168178
expression = Val{true}, checkvars = false)
169179
end

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,13 @@ function collect_constants(eqs::AbstractArray{T}) where {T} # For generate_tgrad
521521
return constants
522522
end
523523

524+
collect_constants(x::Num) = collect_constants(unwrap(x))
525+
function collect_constants(expr::Symbolic{T}) where {T} # For jump system affect / rate
526+
constants = Symbolic[]
527+
collect_constants!(constants,expr)
528+
return constants
529+
end
530+
524531
function collect_constant!(constants, var)
525532
if isconstant(var)
526533
push!(constants, var)

test/discretesystem.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ end;
1111

1212
# Independent and dependent variables and parameters
1313
@parameters t c nsteps δt β γ
14+
@constants h = 1
1415
D = Difference(t; dt = 0.1)
1516
@variables S(t) I(t) R(t)
1617

17-
infection = rate_to_proportion* c * I / (S + I + R), δt) * S
18-
recovery = rate_to_proportion(γ, δt) * I
18+
infection = rate_to_proportion* c * I / (S * h + I + R), δt * h) * S
19+
recovery = rate_to_proportion * h, δt) * I
1920

2021
# Equations
21-
eqs = [D(S) ~ S - infection,
22+
eqs = [D(S) ~ S - infection * h,
2223
D(I) ~ I + infection - recovery,
2324
D(R) ~ R + recovery]
2425

@@ -99,7 +100,7 @@ D2 = Difference(t; dt = 2)
99100
# Equations
100101
eqs = [
101102
D1(x(t)) ~ 0.4x(t) + 0.3x(t - 1.5) + 0.1x(t - 3),
102-
D2(y(t)) ~ 0.3y(t) + 0.7y(t - 2) + 0.1z,
103+
D2(y(t)) ~ 0.3y(t) + 0.7y(t - 2) + 0.1z * h,
103104
]
104105

105106
# System
@@ -119,7 +120,7 @@ linearized_eqs = [eqs
119120
# observed variable handling
120121
@variables t x(t) RHS(t)
121122
@parameters τ
122-
@named fol = DiscreteSystem([D(x) ~ (1 - x) / τ]; observed = [RHS ~ (1 - x) / τ])
123+
@named fol = DiscreteSystem([D(x) ~ (1 - x) / τ]; observed = [RHS ~ (1 - x) / τ * h])
123124
@test isequal(RHS, @nonamespace fol.RHS)
124125
RHS2 = RHS
125126
@unpack RHS = fol

test/jumpsystem.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ MT = ModelingToolkit
33

44
# basic MT SIR model with tweaks
55
@parameters β γ t
6+
@constants h=1
67
@variables S(t) I(t) R(t)
7-
rate₁ = β * S * I
8-
affect₁ = [S ~ S - 1, I ~ I + 1]
8+
rate₁ = β * S * I * h
9+
affect₁ = [S ~ S - 1 * h, I ~ I + 1]
910
rate₂ = γ * I + t
1011
affect₂ = [I ~ I - 1, R ~ R + 1]
1112
j₁ = ConstantRateJump(rate₁, affect₁)
@@ -52,8 +53,8 @@ jump2.affect!(integrator)
5253
@test all(integrator.u .== mtintegrator.u)
5354

5455
# test MT can make and solve a jump problem
55-
rate₃ = γ * I
56-
affect₃ = [I ~ I - 1, R ~ R + 1]
56+
rate₃ = γ * I * h
57+
affect₃ = [I ~ I * h - 1, R ~ R + 1]
5758
j₃ = ConstantRateJump(rate₃, affect₃)
5859
@named js2 = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ])
5960
u₀ = [999, 1, 0];
@@ -83,6 +84,7 @@ sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
8384
@test all(2 .* sol[S] .== sol[S2])
8485

8586
# test save_positions is working
87+
8688
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false))
8789
sol = solve(jprob, SSAStepper(), saveat = 1.0)
8890
@test all((sol.t) .== collect(0.0:tspan[2]))
@@ -192,7 +194,7 @@ sol = solve(jprob, SSAStepper(), tstops = [1000.0],
192194

193195
# observed variable handling
194196
@variables OBS(t)
195-
@named js5 = JumpSystem([maj1, maj2], t, [S], [β, γ]; observed = [OBS ~ 2 * S])
197+
@named js5 = JumpSystem([maj1, maj2], t, [S], [β, γ]; observed = [OBS ~ 2 * S * h])
196198
OBS2 = OBS
197199
@test isequal(OBS2, @nonamespace js5.OBS)
198200
@unpack OBS = js5

0 commit comments

Comments
 (0)