Skip to content

Commit 5f231c7

Browse files
committed
fix: fix initialization cases
1 parent 0232a0a commit 5f231c7

File tree

9 files changed

+62
-20
lines changed

9 files changed

+62
-20
lines changed

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ abstract type AbstractTimeIndependentSystem <: AbstractSystem end
123123
abstract type AbstractODESystem <: AbstractTimeDependentSystem end
124124
abstract type AbstractMultivariateSystem <: AbstractSystem end
125125
abstract type AbstractOptimizationSystem <: AbstractTimeIndependentSystem end
126+
abstract type AbstractDiscreteSystem <: AbstractTimeDependentSystem end
126127

127128
function independent_variable end
128129

@@ -165,6 +166,7 @@ include("systems/diffeqs/modelingtoolkitize.jl")
165166
include("systems/diffeqs/basic_transformations.jl")
166167

167168
include("systems/discrete_system/discrete_system.jl")
169+
include("systems/discrete_system/implicit_discrete_system.jl")
168170

169171
include("systems/jumps/jumpsystem.jl")
170172

@@ -230,6 +232,7 @@ export DAEFunctionExpr, DAEProblemExpr
230232
export SDESystem, SDEFunction, SDEFunctionExpr, SDEProblemExpr
231233
export SystemStructure
232234
export DiscreteSystem, DiscreteProblem, DiscreteFunction, DiscreteFunctionExpr
235+
export ImplicitDiscreteSystem, ImplicitDiscreteProblem, ImplicitDiscreteFunction, ImplicitDiscreteFunctionExpr
233236
export JumpSystem
234237
export ODEProblem, SDEProblem
235238
export NonlinearFunction, NonlinearFunctionExpr

src/structural_transformation/StructuralTransformations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ export torn_system_jacobian_sparsity
6363
export full_equations
6464
export but_ordered_incidence, lowest_order_variable_mask, highest_order_variable_mask
6565
export computed_highest_diff_variables
66-
export shift2term, lower_shift_varname
66+
export shift2term, lower_shift_varname, simplify_shifts
6767

6868
include("utils.jl")
6969
include("pantelides.jl")

src/structural_transformation/utils.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,11 +452,10 @@ end
452452
# For discrete variables. Turn Shift(t, k)(x(t)) into xₜ₋ₖ(t)
453453
function lower_shift_varname(var, iv)
454454
op = operation(var)
455-
op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t)
456-
if op.steps < 0
455+
if op isa Shift && op.steps < 0
457456
return shift2term(var)
458457
else
459-
return var
458+
return Shift(iv, 0)(var, true)
460459
end
461460
end
462461

src/systems/discrete_system/discrete_system.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ eqs = [x(k+1) ~ σ*(y-x),
1717
@named de = DiscreteSystem(eqs)
1818
```
1919
"""
20-
struct DiscreteSystem <: AbstractTimeDependentSystem
20+
struct DiscreteSystem <: AbstractDiscreteSystem
2121
"""
2222
A tag for the system. If two systems have the same tag, then they are
2323
structurally identical.
@@ -237,6 +237,8 @@ function DiscreteSystem(eqs, iv; kwargs...)
237237
collect(allunknowns), collect(new_ps); kwargs...)
238238
end
239239

240+
DiscreteSystem(eq::Equation, args...; kwargs...) = DiscreteSystem([eq], args...; kwargs...)
241+
240242
function flatten(sys::DiscreteSystem, noeqs = false)
241243
systems = get_systems(sys)
242244
if isempty(systems)
@@ -271,14 +273,16 @@ function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
271273
if !((op = operation(k)) isa Shift)
272274
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
273275
end
274-
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
276+
k_next = Shift(iv, op.steps + 1)(arguments(k)[1])
277+
operation(k_next) isa Shift ? updated[shift2term(k_next)] = v :
278+
updated[k_next] = v
275279
end
276280
for var in unknowns(sys)
277281
op = operation(var)
278282
haskey(updated, var) && continue
279283
root = getunshifted(var)
280284
isnothing(root) && continue
281-
haskey(defs, root) || error("Initial condition for $root not provided.")
285+
haskey(defs, root) || error("Initial condition for $var not provided.")
282286
updated[var] = defs[root]
283287
end
284288
return updated

src/systems/discrete_system/implicit_discrete_system.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ function ImplicitDiscreteSystem(eqs, iv; kwargs...)
239239
collect(allunknowns), collect(new_ps); kwargs...)
240240
end
241241

242+
ImplicitDiscreteSystem(eq::Equation, args...; kwargs...) = ImplicitDiscreteSystem([eq], args...; kwargs...)
243+
242244
function flatten(sys::ImplicitDiscreteSystem, noeqs = false)
243245
systems = get_systems(sys)
244246
if isempty(systems)
@@ -261,10 +263,14 @@ end
261263

262264
function generate_function(
263265
sys::ImplicitDiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
264-
exprs = [eq.lhs - eq.rhs for eq in equations(sys)]
265-
u = dvs
266-
u_next = map(Shift(iv, 1), u)
267-
generate_custom_function(sys, exprs, u_next, u, ps..., get_iv(sys); kwargs...)
266+
iv = get_iv(sys)
267+
exprs = map(equations(sys)) do eq
268+
_iszero(eq.lhs) ? eq.rhs : (simplify_shifts(Shift(iv, -1)(eq.rhs)) - simplify_shifts(Shift(iv, -1)(eq.lhs)))
269+
end
270+
271+
u_next = dvs
272+
u = map(Shift(iv, -1), u_next)
273+
build_function_wrapper(sys, exprs, u_next, u, ps..., iv; p_start = 3, kwargs...)
268274
end
269275

270276
function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
@@ -275,13 +281,13 @@ function shift_u0map_forward(sys::ImplicitDiscreteSystem, u0map, defs)
275281
if !((op = operation(k)) isa Shift)
276282
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
277283
end
278-
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
284+
updated[shift2term(k)] = v
279285
end
280286
for var in unknowns(sys)
281287
op = operation(var)
282-
op isa Shift || continue
283288
haskey(updated, var) && continue
284-
root = first(arguments(var))
289+
root = getunshifted(var)
290+
isnothing(root) && continue
285291
haskey(defs, root) || error("Initial condition for $var not provided.")
286292
updated[var] = defs[root]
287293
end
@@ -301,7 +307,7 @@ function SciMLBase.ImplicitDiscreteProblem(
301307
kwargs...
302308
)
303309
if !iscomplete(sys)
304-
error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`")
310+
error("A completed `ImplicitDiscreteSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `ImplicitDiscreteProblem`.")
305311
end
306312
dvs = unknowns(sys)
307313
ps = parameters(sys)
@@ -312,8 +318,7 @@ function SciMLBase.ImplicitDiscreteProblem(
312318
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
313319
f, u0, p = process_SciMLProblem(
314320
ImplicitDiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
315-
u0 = f(u0, p, tspan[1])
316-
NonlinearProblem(f, u0, tspan, p; kwargs...)
321+
ImplicitDiscreteProblem(f, u0, tspan, p; kwargs...)
317322
end
318323

319324
function SciMLBase.ImplicitDiscreteFunction(sys::ImplicitDiscreteSystem, args...; kwargs...)

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ function structural_simplify(
4242
if newsys isa DiscreteSystem &&
4343
any(eq -> symbolic_type(eq.lhs) == NotSymbolic(), equations(newsys))
4444
error("""
45-
Encountered algebraic equations when simplifying discrete system. This is \
46-
not yet supported.
45+
Encountered algebraic equations when simplifying discrete system. Please construct \
46+
an ImplicitDiscreteSystem instead.
4747
""")
4848
end
4949
for pass in additional_passes

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ function TearingState(sys; quick_cancel = false, check = true)
438438

439439
ts = TearingState(sys, fullvars,
440440
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
441-
complete(graph), nothing, var_types, sys isa DiscreteSystem),
441+
complete(graph), nothing, var_types, sys isa AbstractDiscreteSystem),
442442
Any[])
443443
if sys isa DiscreteSystem
444444
ts = shift_discrete_system(ts)

test/implicit_discrete_system.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using ModelingToolkit, Test
2+
using ModelingToolkit: t_nounits as t
3+
4+
k = ShiftIndex(t)
5+
@variables x(t) = 1
6+
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
7+
tspan = (0, 10)
8+
9+
# Shift(t, -1)(x(t)) - x_{t-1}(t)
10+
# -3 - x(t) + x(t)*x_{t-1}
11+
f = ImplicitDiscreteFunction(sys)
12+
u_next = [3., 1.5]
13+
@test f(u_next, [2.,3.], [], t) [0., 0.]
14+
u_next = [0., 0.]
15+
@test f(u_next, [2.,3.], [], t) [3., -3.]
16+
17+
resid = rand(2)
18+
f(resid, u_next, [2.,3.], [], t)
19+
@test resid [3., -3.]
20+
21+
# Initialization cases.
22+
prob = ImplicitDiscreteProblem(sys, [x(k-1) => 3.], tspan)
23+
@test prob.u0 == [3., 1.]
24+
prob = ImplicitDiscreteProblem(sys, [], tspan)
25+
@test prob.u0 == [1., 1.]
26+
@variables x(t)
27+
@mtkbuild sys = ImplicitDiscreteSystem([x(k) ~ x(k)*x(k-1) - 3], t)
28+
@test_throws ErrorException prob = ImplicitDiscreteProblem(sys, [], tspan)
29+
30+
# Test solvers

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ end
8080
@safetestset "Variable Metadata Test" include("test_variable_metadata.jl")
8181
@safetestset "OptimizationSystem Test" include("optimizationsystem.jl")
8282
@safetestset "Discrete System" include("discrete_system.jl")
83+
@safetestset "Implicit Discrete System" include("implicit_discrete_system.jl")
8384
@safetestset "SteadyStateSystem Test" include("steadystatesystems.jl")
8485
@safetestset "SDESystem Test" include("sdesystem.jl")
8586
@safetestset "DDESystem Test" include("dde.jl")

0 commit comments

Comments
 (0)