Skip to content

Commit fc2a309

Browse files
committed
fix tests and shift2term
1 parent 2a97325 commit fc2a309

File tree

3 files changed

+25
-26
lines changed

3 files changed

+25
-26
lines changed

src/structural_transformation/utils.jl

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -469,16 +469,21 @@ end
469469
Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
470470
"""
471471
function shift2term(var)
472-
backshift = operation(var).steps
473-
iv = operation(var).t
472+
op = operation(var)
473+
iv = op.t
474+
arg = only(arguments(var))
475+
is_lowered = !isnothing(ModelingToolkit.getunshifted(arg))
476+
477+
backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps
478+
474479
num = join(Char(0x2080 + d) for d in reverse!(digits(-backshift))) # subscripted number, e.g. ₁
475480
ds = join([Char(0x209c), Char(0x208b), num])
476481
# Char(0x209c) = ₜ
477482
# Char(0x208b) = ₋ (subscripted minus)
478483

479-
O = only(arguments(var))
484+
O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg
480485
oldop = operation(O)
481-
newname = Symbol(string(nameof(oldop)), ds)
486+
newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : Symbol(string(nameof(oldop)))
482487

483488
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
484489
newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
@@ -487,24 +492,6 @@ function shift2term(var)
487492
return newvar
488493
end
489494

490-
function term2shift(var)
491-
var = Symbolics.unwrap(var)
492-
name = Symbolics.getname(var)
493-
O = only(arguments(var))
494-
oldop = operation(O)
495-
iv = only(arguments(x))
496-
# Split on ₋
497-
if occursin(Char(0x208b), name)
498-
substrings = split(name, Char(0x208b))
499-
shift = last(split(name, Char(0x208b)))
500-
newname = join(substrings[1:end-1])[1:end-1]
501-
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), Symbolics.children(O), Symbolics.metadata(O))
502-
return Shift(iv, -shift)(newvar)
503-
else
504-
return var
505-
end
506-
end
507-
508495
function isdoubleshift(var)
509496
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
510497
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)

src/systems/discrete_system/discrete_system.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,15 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
269269
for k in collect(keys(u0map))
270270
v = u0map[k]
271271
if !((op = operation(k)) isa Shift)
272-
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
272+
isnothing(getunshifted(k)) && error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
273+
274+
updated[Shift(iv, 1)(k)] = v
273275
elseif op.steps > 0
274276
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(only(arguments(k)))).")
277+
else
278+
updated[Shift(iv, op.steps + 1)(only(arguments(k)))] = v
275279
end
276280

277-
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
278281
end
279282
for var in unknowns(sys)
280283
op = operation(var)

test/discrete_system.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ end
282282
prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
283283
@test prob[x] == 3.0
284284
@test prob[x(k - 1)] == 2.0
285+
@variables xₜ₋₁(t)
285286
@test prob[xₜ₋₁] == 2.0
286287

287288
# Test initial assignment with lowered variable
288-
@variables xₜ₋₁(t)
289289
prob = DiscreteProblem(de, [xₜ₋₁(k-1) => 4.0], (0, 10))
290290
@test prob[x(k-1)] == prob[xₜ₋₁] == 1.0
291291
@test prob[x] == 5.
@@ -298,16 +298,17 @@ end
298298

299299
# Test non-assigned initials are given default value
300300
@variables x(t) = 2.
301+
@mtkbuild de = DiscreteSystem([x ~ x(k-1) + x(k-2)*x(k-3)], t)
301302
prob = DiscreteProblem(de, [x(k-3) => 12.], (0, 10))
302303
@test prob[x] == 26.0
303304
@test prob[x(k-1)] == 2.0
304305
@test prob[x(k-2)] == 2.0
305306

306307
# Elaborate test
308+
@variables xₜ₋₂(t) zₜ₋₁(t) z(t)
307309
eqs = [x ~ x(k-1) + z(k-2),
308310
z ~ x(k-2) * x(k-3) - z(k-1)^2]
309311
@mtkbuild de = DiscreteSystem(eqs, t)
310-
@variables xₜ₋₂(t) zₜ₋₁(t)
311312
u0 = [x(k-1) => 3,
312313
xₜ₋₂(k-1) => 4,
313314
x(k-2) => 1,
@@ -316,4 +317,12 @@ end
316317
prob = DiscreteProblem(de, u0, (0, 10))
317318
@test prob[x] == 15
318319
@test prob[z] == -21
320+
321+
import ModelingToolkit: shift2term
322+
# unknowns(de) = xₜ₋₁, x, zₜ₋₁, xₜ₋₂, z
323+
vars = ModelingToolkit.value.(unknowns(de))
324+
@test isequal(shift2term(Shift(t, 1)(vars[1])), vars[2])
325+
@test isequal(shift2term(Shift(t, 1)(vars[4])), vars[1])
326+
@test isequal(shift2term(Shift(t, -1)(vars[5])), vars[3])
327+
@test isequal(shift2term(Shift(t, -2)(vars[2])), vars[4])
319328
end

0 commit comments

Comments
 (0)