Skip to content

Commit f4f8302

Browse files
author
Lucas Morton
committed
Merge branch 'master' into uniting
2 parents c0823f5 + dd64720 commit f4f8302

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

src/systems/discrete_system/discrete_system.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
116116
dvs = states(sys)
117117
ps = parameters(sys)
118118
eqs = equations(sys)
119+
eqs = linearize_eqs(sys, eqs)
119120
# defs = defaults(sys)
120121
t = get_iv(sys)
121122
u0 = varmap_to_vars(u0map,dvs)
@@ -129,6 +130,55 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
129130
DiscreteProblem(f,u0,tspan,p;kwargs...)
130131
end
131132

133+
function linearize_eqs(sys, eqs=get_eqs(sys); return_max_delay=false)
134+
unique_states = unique(operation.(states(sys)))
135+
max_delay = Dict(v=>0.0 for v in unique_states)
136+
137+
r = @rule ~t::(t -> istree(t) && any(isequal(operation(t)), operation.(states(sys))) && is_delay_var(get_iv(sys), t)) => begin
138+
delay = get_delay_val(get_iv(sys), first(arguments(~t)))
139+
if delay > max_delay[operation(~t)]
140+
max_delay[operation(~t)] = delay
141+
end
142+
nothing
143+
end
144+
SymbolicUtils.Postwalk(r).(rhss(eqs))
145+
146+
if any(values(max_delay) .> 0)
147+
148+
dts = Dict(v=>Any[] for v in unique_states)
149+
state_ops = Dict(v=>Any[] for v in unique_states)
150+
for v in unique_states
151+
for eq in eqs
152+
if isdifferenceeq(eq) && istree(arguments(eq.lhs)[1]) && isequal(v, operation(arguments(eq.lhs)[1]))
153+
append!(dts[v], [operation(eq.lhs).dt])
154+
append!(state_ops[v], [operation(eq.lhs)])
155+
end
156+
end
157+
end
158+
159+
all(length.(unique.(values(state_ops))) .<= 1) || error("Each state should be used with single difference operator.")
160+
161+
dts_gcd = Dict()
162+
for v in keys(dts)
163+
dts_gcd[v] = (length(dts[v]) > 0) ? first(dts[v]) : nothing
164+
end
165+
166+
lin_eqs = [
167+
v(get_iv(sys) - (t)) ~ v(get_iv(sys) - (t-dts_gcd[v]))
168+
for v in unique_states if max_delay[v] > 0 && dts_gcd[v]!==nothing for t in collect(max_delay[v]:(-dts_gcd[v]):0)[1:end-1]
169+
]
170+
eqs = vcat(eqs, lin_eqs)
171+
end
172+
if return_max_delay return eqs, max_delay end
173+
eqs
174+
end
175+
176+
function get_delay_val(iv, x)
177+
delay = x - iv
178+
isequal(delay > 0, true) && error("Forward delay not permitted")
179+
return -delay
180+
end
181+
132182
check_difference_variables(eq) = check_operator_variables(eq, Difference)
133183

134184
function generate_function(

src/utils.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,25 @@ function check_parameters(ps, iv)
105105
end
106106
end
107107

108+
function is_delay_var(iv, var)
109+
args = nothing
110+
try
111+
args = arguments(var)
112+
catch
113+
return false
114+
end
115+
length(args) > 1 && return false
116+
isequal(first(args), iv) && return false
117+
delay = iv - first(args)
118+
delay isa Integer ||
119+
delay isa AbstractFloat ||
120+
(delay isa Num && isreal(value(delay)))
121+
end
122+
108123
function check_variables(dvs, iv)
109124
for dv in dvs
110125
isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables."))
111-
occursin(iv, iv_from_nested_derivative(dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
126+
(is_delay_var(iv, dv) || occursin(iv, iv_from_nested_derivative(dv))) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv."))
112127
end
113128
end
114129

test/discretesystem.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,37 @@ prob_map = DiscreteProblem(sir_map!,u0,tspan,p);
5656
sol_map2 = solve(prob_map,FunctionMap());
5757

5858
@test Array(sol_map) Array(sol_map2)
59+
60+
# Delayed difference equation
61+
@parameters t
62+
@variables x(..) y(..) z(t)
63+
D1 = Difference(t; dt=1.5)
64+
D2 = Difference(t; dt=2)
65+
66+
@test ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(x(t-2)))
67+
@test ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(y(t-1)))
68+
@test !ModelingToolkit.is_delay_var(Symbolics.value(t), Symbolics.value(z))
69+
@test_throws ErrorException ModelingToolkit.get_delay_val(Symbolics.value(t), Symbolics.arguments(Symbolics.value(x(t+2)))[1])
70+
@test_throws ErrorException z(t)
71+
72+
# Equations
73+
eqs = [
74+
D1(x(t)) ~ 0.4x(t) + 0.3x(t-1.5) + 0.1x(t-3),
75+
D2(y(t)) ~ 0.3y(t) + 0.7y(t-2) + 0.1z,
76+
]
77+
78+
# System
79+
@named sys = DiscreteSystem(eqs,t,[x(t),x(t-1.5),x(t-3),y(t),y(t-2),z],[])
80+
81+
eqs2, max_delay = ModelingToolkit.linearize_eqs(sys; return_max_delay=true)
82+
83+
@test max_delay[Symbolics.operation(Symbolics.value(x(t)))] 3
84+
@test max_delay[Symbolics.operation(Symbolics.value(y(t)))] 2
85+
86+
linearized_eqs = [
87+
eqs
88+
x(t - 3.0) ~ x(t - 1.5)
89+
x(t - 1.5) ~ x(t)
90+
y(t - 2.0) ~ y(t)
91+
]
92+
@test all(eqs2 .== linearized_eqs)

0 commit comments

Comments
 (0)