Skip to content

Commit 00a8222

Browse files
committed
add shift2term for positive shifts
1 parent a85c1c5 commit 00a8222

File tree

4 files changed

+71
-26
lines changed

4 files changed

+71
-26
lines changed

src/discretedomain.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ Base.literal_pow(f::typeof(^), D::Shift, ::Val{n}) where {n} = Shift(D.t, D.step
7070

7171
hasshift(eq::Equation) = hasshift(eq.lhs) || hasshift(eq.rhs)
7272

73+
"""
74+
Next(x)
75+
76+
An alias for Shift(t, 1)(x).
77+
"""
78+
Next(x) = Shift(t, 1)(x)
79+
"""
80+
Prev(x)
81+
82+
An alias for Shift(t, -1)(x).
83+
"""
84+
Prev(x) = Shift(t, -1)(x)
85+
7386
"""
7487
hasshift(O)
7588

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ export torn_system_jacobian_sparsity
6565
export full_equations
6666
export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask
6767
export computed_highest_diff_variables
68-
export shift2term, lower_shift_varname, simplify_shifts
68+
export shift2term, lower_shift_varname, simplify_shifts, distribute_shift
6969

7070
include("utils.jl")
7171
include("pantelides.jl")

src/structural_transformation/utils.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ Handle renaming variable names for discrete structural simplification. Three cas
457457
"""
458458
function lower_shift_varname(var, iv)
459459
op = operation(var)
460-
if op isa Shift && op.steps < 0
460+
if op isa Shift
461461
return shift2term(var)
462462
else
463463
return Shift(iv, 0)(var, true)
@@ -475,10 +475,14 @@ function shift2term(var)
475475

476476
backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps
477477

478-
num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
479-
ds = join([Char(0x209c), Char(0x208b), num])
480-
# Char(0x209c) = ₜ
481478
# Char(0x208b) = ₋ (subscripted minus)
479+
# Char(0x208a) = ₊ (subscripted plus)
480+
pm = backshift > 0 ? Char(0x208a) : Char(0x208b)
481+
# subscripted number, e.g. ₁
482+
num = join(Char(0x2080 + d) for d in reverse!(digits(abs(backshift))))
483+
# Char(0x209c) = ₜ
484+
# ds = ₜ₋₁
485+
ds = join([Char(0x209c), pm, num])
482486

483487
O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg
484488
oldop = operation(O)
@@ -498,6 +502,9 @@ function isdoubleshift(var)
498502
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
499503
end
500504

505+
"""
506+
Simplify multiple shifts: Shift(t, k1)(Shift(t, k2)(x)) becomes Shift(t, k1+k2)(x).
507+
"""
501508
function simplify_shifts(var)
502509
ModelingToolkit.hasshift(var) || return var
503510
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
@@ -518,6 +525,11 @@ function simplify_shifts(var)
518525
end
519526
end
520527

528+
"""
529+
Distribute a shift applied to a whole expression or equation.
530+
Shift(t, 1)(x + y) will become Shift(t, 1)(x) + Shift(t, 1)(y).
531+
Only shifts variables whose independent variable is the same t that appears in the Shift (i.e. constants, time-independent parameters, etc. do not get shifted).
532+
"""
521533
function distribute_shift(var)
522534
var = unwrap(var)
523535
var isa Equation && return distribute_shift(var.lhs) ~ distribute_shift(var.rhs)

test/implicit_discrete_system.jl

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,58 @@
11
using ModelingToolkit, Test
22
using ModelingToolkit: t_nounits as t
3+
using StableRNGs
34

45
k = ShiftIndex(t)
5-
@variables x(t) = 1
6-
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
7-
tspan = (0, 10)
6+
rng = StableRNG(22525)
87

98
# Shift(t, -1)(x(t)) - x_{t-1}(t)
109
# -3 - x(t) + x(t)*x_{t-1}
11-
f = ImplicitDiscreteFunction(sys)
12-
u_next = [3., 1.5]
13-
@test f(u_next, [2.,3.], [], t) [0., 0.]
14-
u_next = [0., 0.]
15-
@test f(u_next, [2.,3.], [], t) [3., -3.]
16-
17-
resid = rand(2)
18-
f(resid, u_next, [2.,3.], [], t)
19-
@test resid [3., -3.]
20-
21-
prob = ImplicitDiscreteProblem(sys, [x(k-1) => 3.], tspan)
22-
@test prob.u0 == [3., 1.]
23-
prob = ImplicitDiscreteProblem(sys, [], tspan)
24-
@test prob.u0 == [1., 1.]
25-
@variables x(t)
26-
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
27-
@test_throws ErrorException prob = ImplicitDiscreteProblem(sys, [], tspan)
10+
@testset "Correct ImplicitDiscreteFunction" begin
11+
@variables x(t) = 1
12+
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
13+
tspan = (0, 10)
14+
f = ImplicitDiscreteFunction(sys)
15+
u_next = [3., 1.5]
16+
@test f(u_next, [2.,3.], [], t) [0., 0.]
17+
u_next = [0., 0.]
18+
@test f(u_next, [2.,3.], [], t) [3., -3.]
19+
20+
resid = rand(2)
21+
f(resid, u_next, [2.,3.], [], t)
22+
@test resid [3., -3.]
23+
24+
prob = ImplicitDiscreteProblem(sys, [x(k-1) => 3.], tspan)
25+
@test prob.u0 == [3., 1.]
26+
prob = ImplicitDiscreteProblem(sys, [], tspan)
27+
@test prob.u0 == [1., 1.]
28+
@variables x(t)
29+
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
30+
@test_throws ErrorException prob = ImplicitDiscreteProblem(sys, [], tspan)
31+
end
2832

2933
# Test solvers
3034
@testset "System with algebraic equations" begin
3135
@variables x(t) y(t)
3236
eqs = [x(k) ~ x(k-1) + x(k-2),
3337
x^2 ~ 1 - y^2]
3438
@mtkbuild sys = ImplicitDiscreteSystem(eqs, t)
39+
f = ImplicitDiscreteFunction(sys)
40+
41+
function correct_f(u_next, u, p, t)
42+
[u[2] - u_next[1],
43+
u[1] + u[2] - u_next[2],
44+
1 - (u_next[1]+u_next[2])^2 - u_next[3]^2]
45+
end
46+
47+
for _ in 1:10
48+
u_next = rand(rng, 3)
49+
u = rand(rng, 3)
50+
@test correct_f(u_next, u, [], 0.) f(u_next, u, [], 0.)
51+
end
52+
53+
# Initialization is satisfied.
54+
prob = ImplicitDiscreteProblem(sys, [x(k-1) => 3.], tspan)
55+
@test (prob.u0[1] + prob.u0[2])^2 + prob.u0[3]^2 1
3556
end
3657

3758
@testset "System with algebraic equations, implicit difference equations, explicit difference equations" begin
@@ -40,4 +61,3 @@ end
4061
y(k) ~ x(k) + x(k-2)*y(k-1)]
4162
@mtkbuild sys = ImplicitDiscreteSystem(eqs, t)
4263
end
43-

0 commit comments

Comments
 (0)