Skip to content

Commit c504aef

Browse files
committed
fix: make the codegen work with shifted observables
1 parent 00a8222 commit c504aef

File tree

4 files changed

+47
-16
lines changed

4 files changed

+47
-16
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
558558
end
559559

560560
total_sub[simplify_shifts(neweq.lhs)] = neweq.rhs
561+
# Substitute unshifted variables x(k), y(k) on RHS of implicit equations
562+
if is_only_discrete(structure)
563+
var_to_diff[iv] === nothing && (total_sub[var] = neweq.rhs)
564+
end
561565
push!(diff_eqs, neweq)
562566
push!(diffeq_idxs, ieq)
563567
push!(diff_vars, diff_to_var[iv])

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,20 @@ end
264264
function generate_function(
265265
sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
266266
iv = get_iv(sys)
267+
# Algebraic equations get shifted forward 1, to match with differential equations
267268
exprs = map(equations(sys)) do eq
268-
_iszero(eq.lhs) ? eq.rhs : (distribute_shift(Shift(iv, -1)(eq.rhs)) - distribute_shift(Shift(iv, -1)(eq.lhs)))
269+
_iszero(eq.lhs) ? distribute_shift(Shift(iv, 1)(eq.rhs)) : (eq.rhs - eq.lhs)
269270
end
270271

271-
u_next = dvs
272-
u = map(Shift(iv, -1), u_next)
273-
build_function_wrapper(sys, exprs, u_next, u, ps..., iv; p_start = 3, kwargs...)
272+
# Handle observables in algebraic equations, since they are shifted
273+
obs = observed(sys)
274+
shifted_obs = [distribute_shift(Shift(iv, 1)(eq)) for eq in obs]
275+
obsidxs = observed_equations_used_by(sys, exprs; obs = shifted_obs)
276+
extra_assignments = [Assignment(shifted_obs[i].lhs, shifted_obs[i].rhs) for i in obsidxs]
277+
278+
u_next = map(Shift(iv, 1), dvs)
279+
u = dvs
280+
build_function_wrapper(sys, exprs, u_next, u, ps..., iv; p_start = 3, extra_assignments, kwargs...)
274281
end
275282

276283
function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
@@ -279,15 +286,22 @@ function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
279286
for k in collect(keys(u0map))
280287
v = u0map[k]
281288
if !((op = operation(k)) isa Shift)
282-
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
289+
isnothing(getunshifted(k)) &&
290+
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
291+
292+
updated[k] = v
293+
elseif op.steps > 0
294+
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k))).")
295+
else
296+
updated[k] = v
283297
end
284-
updated[shift2term(k)] = v
285298
end
286299
for var in unknowns(sys)
287300
op = operation(var)
288-
haskey(updated, var) && continue
289301
root = getunshifted(var)
302+
shift = getshift(var)
290303
isnothing(root) && continue
304+
(haskey(updated, Shift(iv, shift)(root)) || haskey(updated, var)) && continue
291305
haskey(defs, root) || error("Initial condition for $var not provided.")
292306
updated[var] = defs[root]
293307
end
@@ -317,7 +331,9 @@ function SciMLBase.ImplicitDiscreteProblem(
317331
u0map = to_varmap(u0map, dvs)
318332
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
319333
f, u0, p = process_SciMLProblem(
320-
ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
334+
ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...)
335+
336+
kwargs = filter_kwargs(kwargs)
321337
ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
322338
end
323339

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ function TearingState(sys; quick_cancel = false, check = true)
440440
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
441441
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
442442
Any[])
443-
if sys isa AbstractDiscreteSystem
443+
if sys isa DiscreteSystem
444444
ts = shift_discrete_system(ts)
445445
end
446446
return ts

test/implicit_discrete_system.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ using StableRNGs
55
k = ShiftIndex(t)
66
rng = StableRNG(22525)
77

8-
# Shift(t, -1)(x(t)) - x_{t-1}(t)
9-
# -3 - x(t) + x(t)*x_{t-1}
108
@testset "Correct ImplicitDiscreteFunction" begin
119
@variables x(t) = 1
1210
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
1311
tspan = (0, 10)
12+
13+
# u[2] - u_next[1]
14+
# -3 - u_next[2] + u_next[2]*u_next[1]
1415
f = ImplicitDiscreteFunction(sys)
1516
u_next = [3., 1.5]
1617
@test f(u_next, [2.,3.], [], t) [0., 0.]
@@ -30,7 +31,6 @@ rng = StableRNG(22525)
3031
@test_throws ErrorException prob = ImplicitDiscreteProblem(sys, [], tspan)
3132
end
3233

33-
# Test solvers
3434
@testset "System with algebraic equations" begin
3535
@variables x(t) y(t)
3636
eqs = [x(k) ~ x(k-1) + x(k-2),
@@ -51,13 +51,24 @@ end
5151
end
5252

5353
# Initialization is satisfied.
54-
prob = ImplicitDiscreteProblem(sys, [x(k-1) => 3.], tspan)
54+
prob = ImplicitDiscreteProblem(sys, [x(k-1) => .3, x(k-2) => .4], (0, 10), guesses = [y => 1])
5555
@test (prob.u0[1] + prob.u0[2])^2 + prob.u0[3]^2 1
5656
end
5757

58-
@testset "System with algebraic equations, implicit difference equations, explicit difference equations" begin
59-
@variables x(t) y(t)
58+
@testset "Handle observables in function codegen" begin
59+
# Observable appears in differential equation
60+
@variables x(t) y(t) z(t)
6061
eqs = [x(k) ~ x(k-1) + x(k-2),
61-
y(k) ~ x(k) + x(k-2)*y(k-1)]
62+
y(k) ~ x(k) + x(k-2)*z(k-1),
63+
x + y + z ~ 2]
64+
@mtkbuild sys = ImplicitDiscreteSystem(eqs, t)
65+
@test length(unknowns(sys)) == length(equations(sys)) == 3
66+
@test occursin("var\"y(t)\"", string(ImplicitDiscreteFunctionExpr(sys)))
67+
68+
# Shifted observable that appears in algebraic equation is properly handled.
69+
eqs = [z(k) ~ x(k) + sin(x(k)),
70+
y(k) ~ x(k-1) + x(k-2),
71+
z(k) * x(k) ~ 3]
6272
@mtkbuild sys = ImplicitDiscreteSystem(eqs, t)
73+
@test occursin("var\"Shift(t, 1)(z(t))\"", string(ImplicitDiscreteFunctionExpr(sys)))
6374
end

0 commit comments

Comments
 (0)