Skip to content

Commit 6fb59de

Browse files
committed
feat: implement
1 parent 8db7590 commit 6fb59de

File tree

5 files changed

+71
-3
lines changed

5 files changed

+71
-3
lines changed

src/structural_transformation/utils.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,3 +507,38 @@ function simplify_shifts(var)
507507
unwrap(var).metadata)
508508
end
509509
end
510+
511+
function distribute_shift(var)
512+
var = unwrap(var)
513+
var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs)
514+
515+
ModelingToolkit.hasshift(var) || return var
516+
shift = operation(var)
517+
shift isa Shift || return var
518+
519+
shift = operation(var)
520+
expr = only(arguments(var))
521+
if expr isa Equation
522+
return distribute_shift(shift(expr.lhs)) ~ distribute_shift(shift(expr.rhs))
523+
end
524+
shiftexpr = _distribute_shift(expr, shift)
525+
return simplify_shifts(shiftexpr)
526+
end
527+
528+
function _distribute_shift(expr, shift)
529+
if iscall(expr)
530+
op = operation(expr)
531+
args = arguments(expr)
532+
533+
if ModelingToolkit.isvariable(expr)
534+
(length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) : (return expr)
535+
elseif op isa Shift
536+
return shift(expr)
537+
else
538+
return maketerm(typeof(expr), operation(expr), Base.Fix2(_distribute_shift, shift).(args),
539+
unwrap(expr).metadata)
540+
end
541+
else
542+
return expr
543+
end
544+
end

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ function generate_function(
265265
sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
266266
iv = get_iv(sys)
267267
exprs = map(equations(sys)) do eq
268-
_iszero(eq.lhs) ? eq.rhs : (simplify_shifts(Shift(iv, -1)(eq.rhs)) - simplify_shifts(Shift(iv, -1)(eq.lhs)))
268+
_iszero(eq.lhs) ? eq.rhs : (distribute_shift(Shift(iv, -1)(eq.rhs)) - distribute_shift(Shift(iv, -1)(eq.lhs)))
269269
end
270270

271271
u_next = dvs

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ function TearingState(sys; quick_cancel = false, check = true)
440440
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
441441
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
442442
Any[])
443-
if sys isa DiscreteSystem
443+
if sys isa AbstractDiscreteSystem
444444
ts = shift_discrete_system(ts)
445445
end
446446
return ts

test/implicit_discrete_system.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ resid = rand(2)
1818
f(resid, u_next, [2.,3.], [], t)
1919
@test resid [3., -3.]
2020

21-
# Initialization cases.
2221
prob = ImplicitDiscreteProblem(sys, [x(k-1) => 3.], tspan)
2322
@test prob.u0 == [3., 1.]
2423
prob = ImplicitDiscreteProblem(sys, [], tspan)
@@ -28,3 +27,17 @@ prob = ImplicitDiscreteProblem(sys, [], tspan)
2827
@test_throws ErrorException prob = ImplicitDiscreteProblem(sys, [], tspan)
2928

3029
# Test solvers
30+
@testset "System with algebraic equations" begin
31+
@variables x(t) y(t)
32+
eqs = [x(k) ~ x(k-1) + x(k-2),
33+
x^2 ~ 1 - y^2]
34+
@mtkbuild sys = ImplicitDiscreteSystem(eqs, t)
35+
end
36+
37+
@testset "System with algebraic equations, implicit difference equations, explicit difference equations" begin
38+
@variables x(t) y(t)
39+
eqs = [x(k) ~ x(k-1) + x(k-2),
40+
y(k) ~ x(k) + x(k-2)*y(k-1)]
41+
@mtkbuild sys = ImplicitDiscreteSystem(eqs, t)
42+
end
43+

test/structural_transformation/utils.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,23 @@ end
162162
structural_simplify(sys; additional_passes = [pass])
163163
@test value[] == 1
164164
end
165+
166+
@testset "Distribute shifts" begin
167+
@variables x(t) y(t) z(t)
168+
@parameters a b c
169+
170+
# Expand shifts
171+
@test isequal(ST.distribute_shift(Shift(t, -1)(x + y)), Shift(t, -1)(x) + Shift(t, -1)(y))
172+
173+
expr = a * Shift(t, -2)(x) + Shift(t, 2)(y) + b
174+
@test isequal(ST.simplify_shifts(ST.distribute_shift(Shift(t, 2)(expr))),
175+
a*x + Shift(t, 4)(y) + b)
176+
@test isequal(ST.distribute_shift(Shift(t, 2)(exp(z))), exp(Shift(t, 2)(z)))
177+
@test isequal(ST.distribute_shift(Shift(t, 2)(exp(a) + b)), exp(a) + b)
178+
179+
expr = a^x - log(b*y) + z*x
180+
@test isequal(ST.distribute_shift(Shift(t, -3)(expr)), a^(Shift(t, -3)(x)) - log(b * Shift(t, -3)(y)) + Shift(t, -3)(z)*Shift(t, -3)(x))
181+
182+
expr = x(k+1) ~ x + x(k-1)
183+
@test isequal(ST.distribute_shift(Shift(t, -1)(expr)), x ~ x(k-1) + x(k-2))
184+
end

0 commit comments

Comments
 (0)