Skip to content

Commit dd64720

Browse files
authored
Linearizing delayed equation systems (#1151)
1 parent c3b902d commit dd64720

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

src/systems/discrete_system/discrete_system.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
113113
dvs = states(sys)
114114
ps = parameters(sys)
115115
eqs = equations(sys)
116+
eqs = linearize_eqs(sys, eqs)
116117
# defs = defaults(sys)
117118
t = get_iv(sys)
118119
u0 = varmap_to_vars(u0map,dvs)
@@ -126,6 +127,55 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
126127
DiscreteProblem(f,u0,tspan,p;kwargs...)
127128
end
128129

130+
function linearize_eqs(sys, eqs=get_eqs(sys); return_max_delay=false)
131+
unique_states = unique(operation.(states(sys)))
132+
max_delay = Dict(v=>0.0 for v in unique_states)
133+
134+
r = @rule ~t::(t -> istree(t) && any(isequal(operation(t)), operation.(states(sys))) && is_delay_var(get_iv(sys), t)) => begin
135+
delay = get_delay_val(get_iv(sys), first(arguments(~t)))
136+
if delay > max_delay[operation(~t)]
137+
max_delay[operation(~t)] = delay
138+
end
139+
nothing
140+
end
141+
SymbolicUtils.Postwalk(r).(rhss(eqs))
142+
143+
if any(values(max_delay) .> 0)
144+
145+
dts = Dict(v=>Any[] for v in unique_states)
146+
state_ops = Dict(v=>Any[] for v in unique_states)
147+
for v in unique_states
148+
for eq in eqs
149+
if isdifferenceeq(eq) && istree(arguments(eq.lhs)[1]) && isequal(v, operation(arguments(eq.lhs)[1]))
150+
append!(dts[v], [operation(eq.lhs).dt])
151+
append!(state_ops[v], [operation(eq.lhs)])
152+
end
153+
end
154+
end
155+
156+
all(length.(unique.(values(state_ops))) .<= 1) || error("Each state should be used with single difference operator.")
157+
158+
dts_gcd = Dict()
159+
for v in keys(dts)
160+
dts_gcd[v] = (length(dts[v]) > 0) ? first(dts[v]) : nothing
161+
end
162+
163+
lin_eqs = [
164+
v(get_iv(sys) - (t)) ~ v(get_iv(sys) - (t-dts_gcd[v]))
165+
for v in unique_states if max_delay[v] > 0 && dts_gcd[v]!==nothing for t in collect(max_delay[v]:(-dts_gcd[v]):0)[1:end-1]
166+
]
167+
eqs = vcat(eqs, lin_eqs)
168+
end
169+
if return_max_delay return eqs, max_delay end
170+
eqs
171+
end
172+
173+
function get_delay_val(iv, x)
174+
delay = x - iv
175+
isequal(delay > 0, true) && error("Forward delay not permitted")
176+
return -delay
177+
end
178+
129179
check_difference_variables(eq) = check_operator_variables(eq, Difference)
130180

131181
function generate_function(

src/utils.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,25 @@ function check_parameters(ps, iv)
105105
end
106106
end
107107

108+
function is_delay_var(iv, var)
109+
args = nothing
110+
try
111+
args = arguments(var)
112+
catch
113+
return false
114+
end
115+
length(args) > 1 && return false
116+
isequal(first(args), iv) && return false
117+
delay = iv - first(args)
118+
delay isa Integer ||
119+
delay isa AbstractFloat ||
120+
(delay isa Num && isreal(value(delay)))
121+
end
122+
108123
function check_variables(dvs, iv)
109124
for dv in dvs
110125
isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables."))
111-
occursin(iv, iv_from_nested_derivative(dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
126+
(is_delay_var(iv, dv) || occursin(iv, iv_from_nested_derivative(dv))) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
112127
end
113128
end
114129

test/discretesystem.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,37 @@ prob_map = DiscreteProblem(sir_map!,u0,tspan,p);
5656
sol_map2 = solve(prob_map,FunctionMap());
5757

5858
@test Array(sol_map) Array(sol_map2)
59+
60+
# Delayed difference equation
61+
@parameters t
62+
@variables x(..) y(..) z(t)
63+
D1 = Difference(t; dt=1.5)
64+
D2 = Difference(t; dt=2)
65+
66+
@test ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(x(t-2)))
67+
@test ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(y(t-1)))
68+
@test !ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(z))
69+
@test_throws ErrorException ModelingToolkit.get_delay_val(Symbolics.value(t), Symbolics.arguments(Symbolics.value(x(t+2)))[1])
70+
@test_throws ErrorException z(t)
71+
72+
# Equations
73+
eqs = [
74+
D1(x(t)) ~ 0.4x(t) + 0.3x(t-1.5) + 0.1x(t-3),
75+
D2(y(t)) ~ 0.3y(t) + 0.7y(t-2) + 0.1z,
76+
]
77+
78+
# System
79+
@named sys = DiscreteSystem(eqs,t,[x(t),x(t-1.5),x(t-3),y(t),y(t-2),z],[])
80+
81+
eqs2, max_delay = ModelingToolkit.linearize_eqs(sys; return_max_delay=true)
82+
83+
@test max_delay[Symbolics.operation(Symbolics.value(x(t)))] 3
84+
@test max_delay[Symbolics.operation(Symbolics.value(y(t)))] 2
85+
86+
linearized_eqs = [
87+
eqs
88+
x(t - 3.0) ~ x(t - 1.5)
89+
x(t - 1.5) ~ x(t)
90+
y(t - 2.0) ~ y(t)
91+
]
92+
@test all(eqs2 .== linearized_eqs)

0 commit comments

Comments
 (0)