Skip to content

Commit c0fe236

Browse files
authored
Merge pull request #1119 from sharanry/sy/use_difference_for_discreteSys
Use Difference operator for discrete system
2 parents c91e276 + 6d52839 commit c0fe236

File tree

4 files changed

+54
-23
lines changed

4 files changed

+54
-23
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,7 @@ function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys), ps
7878
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
7979
end
8080

81-
@noinline function throw_invalid_derivative(dervar, eq)
82-
msg = "The derivative variable must be isolated to the left-hand " *
83-
"side of the equation like `$dervar ~ ...`.\n Got $eq."
84-
throw(InvalidSystemException(msg))
85-
end
86-
87-
function check_derivative_variables(eq, expr=eq.rhs)
88-
istree(expr) || return nothing
89-
if operation(expr) isa Differential
90-
throw_invalid_derivative(expr, eq)
91-
end
92-
foreach(Base.Fix1(check_derivative_variables, eq), arguments(expr))
93-
end
81+
check_derivative_variables(eq) = check_operator_variables(eq, Differential)
9482

9583
function generate_function(
9684
sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);

src/systems/discrete_system/discrete_system.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,29 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
116116
u = dvs
117117
p = varmap_to_vars(parammap,ps)
118118

119-
f_gen = build_function(rhss, dvs, ps, t; expression=Val{eval_expression}, expression_module=eval_module)
120-
f_oop,f_iip = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)
119+
f_gen = generate_function(sys; expression=Val{eval_expression}, expression_module=eval_module)
120+
f_oop, _ = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)
121121
f(u,p,t) = f_oop(u,p,t)
122122
DiscreteProblem(f,u0,tspan,p;kwargs...)
123123
end
124+
125+
isdifference(expr) = istree(expr) && operation(expr) isa Difference
126+
isdifferenceeq(eq) = isdifference(eq.lhs)
127+
128+
check_difference_variables(eq) = check_operator_variables(eq, Difference)
129+
130+
function generate_function(
131+
sys::DiscreteSystem, dvs = states(sys), ps = parameters(sys);
132+
kwargs...
133+
)
134+
eqs = equations(sys)
135+
foreach(check_difference_variables, eqs)
136+
# substitute x(t) by just x
137+
rhss = [eq.rhs for eq in eqs]
138+
139+
u = map(x->time_varying_as_func(value(x), sys), dvs)
140+
p = map(x->time_varying_as_func(value(x), sys), ps)
141+
t = get_iv(sys)
142+
143+
build_function(rhss, u, p, t; kwargs...)
144+
end

src/utils.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,24 @@ function collect_defaults!(defs, vars)
166166
end
167167
return defs
168168
end
169+
170+
"Throw error when difference/derivative operation occurs in the R.H.S."
171+
@noinline function throw_invalid_operator(opvar, eq, op::Type)
172+
if op === Difference
173+
optext = "difference"
174+
elseif op === Differential
175+
optext="derivative"
176+
end
177+
msg = "The $optext variable must be isolated to the left-hand " *
178+
"side of the equation like `$opvar ~ ...`.\n Got $eq."
179+
throw(InvalidSystemException(msg))
180+
end
181+
182+
"Check if difference/derivative operation occurs in the R.H.S. of an equation"
183+
function check_operator_variables(eq, op::Type, expr=eq.rhs)
184+
istree(expr) || return nothing
185+
if operation(expr) isa op
186+
throw_invalid_operator(expr, eq, op)
187+
end
188+
foreach(expr -> check_operator_variables(eq, op, expr), arguments(expr))
189+
end

test/discretesystem.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,24 @@
33
- https://github.com/epirecipes/sir-julia/blob/master/markdown/function_map/function_map.md
44
- https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#Deterministic_versus_stochastic_epidemic_models
55
=#
6-
using ModelingToolkit
6+
using ModelingToolkit, Test
77

88
@inline function rate_to_proportion(r,t)
99
1-exp(-r*t)
1010
end;
1111

1212
# Independent and dependent variables and parameters
1313
@parameters t c nsteps δt β γ
14+
D = Difference(t; dt=0.1)
1415
@variables S(t) I(t) R(t) next_S(t) next_I(t) next_R(t)
1516

1617
infection = rate_to_proportion*c*I/(S+I+R),δt)*S
1718
recovery = rate_to_proportion(γ,δt)*I
1819

1920
# Equations
20-
eqs = [next_S ~ S-infection,
21-
next_I ~ I+infection-recovery,
22-
next_R ~ R+recovery]
21+
eqs = [D(S) ~ S-infection,
22+
D(I) ~ I+infection-recovery,
23+
D(R) ~ R+recovery]
2324

2425
# System
2526
sys = DiscreteSystem(eqs,t,[S,I,R],[c,nsteps,δt,β,γ]; controls = [β, γ])
@@ -36,16 +37,16 @@ sol_map = solve(prob_map,FunctionMap());
3637

3738
# Direct Implementation
3839

39-
function sir_map!(du,u,p,t)
40+
function sir_map!(u_diff,u,p,t)
4041
(S,I,R) = u
4142
(β,c,γ,δt) = p
4243
N = S+I+R
4344
infection = rate_to_proportion*c*I/N,δt)*S
4445
recovery = rate_to_proportion(γ,δt)*I
4546
@inbounds begin
46-
du[1] = S-infection
47-
du[2] = I+infection-recovery
48-
du[3] = R+recovery
47+
u_diff[1] = S-infection
48+
u_diff[2] = I+infection-recovery
49+
u_diff[3] = R+recovery
4950
end
5051
nothing
5152
end;

0 commit comments

Comments
 (0)