Skip to content

Commit 56829f7

Browse files
committed
up
1 parent 91acf91 commit 56829f7

File tree

2 files changed

+73
-52
lines changed

2 files changed

+73
-52
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ function substitute_dummy_derivatives!(ts::TearingState, neweqs, dummy_sub, var_
252252
@unpack fullvars, sys, structure = ts
253253
@unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
254254
diff_to_var = invview(var_to_diff)
255+
iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
255256

256257
for var in 1:length(fullvars)
257258
dv = var_to_diff[var]
@@ -337,31 +338,32 @@ variables and equations, don't add them when they already exist.
337338
###### DISCRETE SYSTEMS #######
338339
339340
Documenting the differences to structural simplification for discrete systems:
340-
In discrete systems the lowest-order term is x_k-i, instead of x(t).
341+
342+
1. In discrete systems the lowest-order term is Shift(t, k)(x(t)), instead of x(t). We need to substitute the k-1 lowest order terms instead of the k-1 highest order terms.
341343
342344
The orders will also be off by one. The reason this is is that the dynamics of
343345
the system should be given in terms of Shift(t, 1)(x(t), x(t-1), ...). But
344346
having the observables be indexed by the next time step is not so nice. So we
345-
handle the shifts in the renaming, rather than explicitly.
347+
handle the shifts in the renaming.
346348
347349
The substitution should look like the following:
348350
x(t) -> Shift(t, 1)(x(t))
349-
Shift(t, -1)(x(t)) -> x(t)
351+
Shift(t, -1)(x(t)) -> Shift(t, 0)(x(t))
350352
Shift(t, -2)(x(t)) -> x_{t-1}(t)
351353
Shift(t, -3)(x(t)) -> x_{t-2}(t)
352354
and so on...
353355
354356
In the implicit discrete case this shouldn't happen. The simplification should
355357
look like a NonlinearSystem.
356358
357-
For discrete systems Shift(t, 2)(x(t)) is not equivalent to Shift(t, 1)(Shift(t,1)(x(t))
359+
2. For discrete systems Shift(t, 2)(x(t)) cannot be substituted as Shift(t, 1)(Shift(t,1)(x(t)).
358360
This is different from the continuous case where D(D(x)) can be substituted for
359361
by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
360-
total_sub dict is updated at the time that the renamed variables are written,
362+
shift_sub dict is updated at the time that the renamed variables are written,
361363
inside the loop where new variables are generated.
362364
"""
363-
function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching;
364-
is_discrete = false, mm = nothing)
365+
function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
366+
is_discrete = false, mm = nothing, shift_sub = nothing)
365367
@unpack fullvars, sys, structure = ts
366368
@unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
367369
eq_var_matching = invview(var_eq_matching)
@@ -391,23 +393,24 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
391393
# - uv is the index of the highest-order variable (x(t))
392394
for v in 1:length(var_to_diff)
393395
dv = var_to_diff[v]
394-
println()
395-
@show (v, dv)
396+
396397
if is_discrete
397398
x = fullvars[v]
398399
op = operation(x)
399400
(low, uv) = idx_to_lowest_shift[v]
400401

401402
# If v is unshifted (i.e. x(t)), then substitute the lowest-shift variable
402-
if !(op isa Shift) && (low != 0)
403+
if !(op isa Shift)
403404
dv = findfirst(_x -> isequal(_x, Shift(iv, low)(x)), fullvars)
404405
end
406+
dx = fullvars[dv]
407+
order, lv = var_order(diff_to_var, dv)
408+
@show fullvars[uv]
409+
x_t = lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv])
410+
shift_sub[dx] = x_t
411+
(var_eq_matching[dv] isa Int) ? continue : @goto DISCRETE_VARIABLE
405412
end
406413
dv isa Int || continue
407-
408-
@show dv
409-
@show var_eq_matching[dv]
410-
@show fullvars
411414
solved = var_eq_matching[dv] isa Int
412415
solved && continue
413416

@@ -430,37 +433,32 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
430433
end
431434
end
432435

433-
dx = fullvars[dv]
434436
# add `x_t`
437+
dx = fullvars[dv]
435438
order, lv = var_order(diff_to_var, dv)
436-
x_t = is_discrete ? lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv]) :
437-
lower_name(fullvars[lv], iv, order)
438-
@show dx, x_t
439-
push!(fullvars, simplify_shifts(x_t))
439+
x_t = lower_name(fullvars[lv], iv, order)
440+
441+
@label DISCRETE_VARIABLE
442+
push!(fullvars, x_t)
440443
v_t = length(fullvars)
441444
v_t_idx = add_vertex!(var_to_diff)
442445
add_vertex!(graph, DST)
443446
# TODO: do we care about solvable_graph? We don't use them after
444447
# `dummy_derivative_graph`.
445448
add_vertex!(solvable_graph, DST)
446449
push!(var_eq_matching, unassigned)
447-
@assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
448-
length(var_eq_matching)
449450

450451
# Add discrete substitutions to total_sub directly.
451452
is_discrete && begin
452453
idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
453-
if operation(dx) isa Shift
454-
total_sub[dx] = x_t
455-
for e in 𝑑neighbors(graph, dv)
456-
add_edge!(graph, e, v_t)
457-
rem_edge!(graph, e, dv)
458-
end
459-
# Do not add the lowest-order substitution as an equation, just substitute
460-
!(operation(x) isa Shift) && begin
461-
var_to_diff[v_t] = var_to_diff[dv]
462-
continue
463-
end
454+
for e in 𝑑neighbors(graph, dv)
455+
add_edge!(graph, e, v_t)
456+
rem_edge!(graph, e, dv)
457+
end
458+
# Do not add the lowest-order substitution as an equation, just substitute
459+
!(operation(x) isa Shift) && begin
460+
var_to_diff[v_t] = var_to_diff[dv]
461+
continue
464462
end
465463
end
466464

@@ -472,14 +470,22 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
472470
add_edge!(graph, dummy_eq, v_t)
473471
add_vertex!(solvable_graph, SRC)
474472
add_edge!(solvable_graph, dummy_eq, dv)
475-
@assert nsrcs(graph) == nsrcs(solvable_graph) == dummy_eq
473+
476474
@label FOUND_DUMMY_EQ
477475
var_to_diff[v_t] = var_to_diff[dv]
478476
var_eq_matching[dv] = unassigned
479477
eq_var_matching[dummy_eq] = dv
480478
end
481479
end
482480

481+
function add_solvable_variable!()
482+
483+
end
484+
485+
function add_solvable_equation!()
486+
487+
end
488+
483489
"""
484490
Solve the solvable equations of the system and generate differential (or discrete)
485491
equations in terms of the selected states.
@@ -492,21 +498,27 @@ such that the mass matrix is:
492498
0 0].
493499
494500
Update the state to account for the new ordering and equations.
501+
502+
####### DISCRETE CASE
503+
- only substitute Shift(t, -2)
495504
"""
496-
function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching; simplify = false)
505+
function solve_and_generate_equations!(state::TearingState, neweqs, var_eq_matching; simplify = false, shift_sub = Dict())
497506
@unpack fullvars, sys, structure = state
498507
@unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
499508
eq_var_matching = invview(var_eq_matching)
500509
diff_to_var = invview(var_to_diff)
510+
dx_sub = Dict()
501511

502512
if ModelingToolkit.has_iv(sys)
503513
iv = get_iv(sys)
504514
if is_only_discrete(structure)
505515
D = Shift(iv, 1)
506516
lower_name = lower_varname_withshift
517+
total_sub = shift_sub
507518
else
508519
D = Differential(iv)
509520
lower_name = lower_varname_with_unit
521+
total_sub = dx_sub
510522
end
511523
else
512524
iv = D = nothing
@@ -540,6 +552,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
540552
# fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
541553
# is solved to give the RHS.
542554
for ieq in eqs
555+
println()
543556
iv = eq_var_matching[ieq]
544557
if is_solvable(ieq, iv)
545558
# We don't solve differential equations, but we will need to try to
@@ -549,7 +562,9 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
549562
isnothing(D) &&
550563
error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
551564
order, lv = var_order(diff_to_var, iv)
552-
dx = D(fullvars[lv])
565+
@show fullvars[lv]
566+
@show simplify_shifts(fullvars[lv])
567+
dx = D(simplify_shifts(fullvars[lv]))
553568
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
554569
Symbolics.symbolic_linear_solve(neweqs[ieq],
555570
fullvars[iv]),
@@ -560,6 +575,9 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
560575
end
561576
push!(diff_eqs, eq)
562577
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
578+
dx_sub[simplify_shifts(eq.lhs)] = eq.rhs
579+
@show total_sub
580+
@show eq
563581
push!(diffeq_idxs, ieq)
564582
push!(diff_vars, diff_to_var[iv])
565583
continue
@@ -575,10 +593,10 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
575593
@warn "Tearing: solving $eq for $var is singular!"
576594
else
577595
rhs = -b / a
578-
neweq = var ~ Symbolics.fixpoint_sub(
596+
neweq = var ~ simplify_shifts(Symbolics.fixpoint_sub(
579597
simplify ?
580598
Symbolics.simplify(rhs) : rhs,
581-
total_sub; operator = ModelingToolkit.Shift)
599+
dx_sub; operator = ModelingToolkit.Shift))
582600
push!(subeqs, neweq)
583601
push!(solved_equations, ieq)
584602
push!(solved_variables, iv)
@@ -589,7 +607,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
589607
if !(eq.lhs isa Number && eq.lhs == 0)
590608
rhs = eq.rhs - eq.lhs
591609
end
592-
push!(alge_eqs, 0 ~ Symbolics.fixpoint_sub(rhs, total_sub))
610+
push!(alge_eqs, 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)))
593611
push!(algeeq_idxs, ieq)
594612
end
595613
end
@@ -676,16 +694,19 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
676694
end
677695
neweqs = collect(equations(state))
678696
diff_to_var = invview(var_to_diff)
679-
total_sub = Dict()
680-
dummy_sub = Dict()
681697
is_discrete = is_only_discrete(state.structure)
682698

699+
shift_sub = Dict()
700+
683701
# Structural simplification
702+
dummy_sub = Dict()
684703
substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
685-
generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching;
686-
is_discrete, mm)
704+
705+
generate_derivative_variables!(state, neweqs, var_eq_matching;
706+
is_discrete, mm, shift_sub)
707+
687708
new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph =
688-
solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching; simplify)
709+
solve_and_generate_equations!(state, neweqs, var_eq_matching; simplify, shift_sub)
689710

690711
# Update system
691712
var_to_diff = new_var_to_diff
@@ -701,12 +722,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
701722
i -> (!isempty(𝑑neighbors(graph, i)) ||
702723
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
703724
end
704-
@show ispresent.(collect(1:length(fullvars)))
705-
@show 𝑑neighbors(graph, 5)
706-
@show var_to_diff[5]
707725

708-
@show neweqs
709-
@show fullvars
710726
sys = state.sys
711727
obs_sub = dummy_sub
712728
for eq in neweqs

src/structural_transformation/utils.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,9 @@ end
449449
### Misc
450450
###
451451

452-
function lower_varname_withshift(var, iv, backshift; unshifted = nothing)
453-
backshift < 0 && return Shift(iv, -backshift)(var)
454-
backshift == 0 && return unshifted
455-
ds = "$iv-$backshift"
452+
function lower_varname_withshift(var, iv, backshift; unshifted = nothing, allow_zero = true)
453+
backshift <= 0 && return Shift(iv, -backshift)(unshifted, allow_zero)
454+
ds = backshift > 0 ? "$iv-$backshift" : "$iv+$(-backshift)"
456455
d_separator = 'ˍ'
457456

458457
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
@@ -475,9 +474,15 @@ function isdoubleshift(var)
475474
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
476475
end
477476

477+
### Rules
478+
# 1. x(t) -> x(t)
479+
# 2. Shift(t, 0)(x(t)) -> x(t)
480+
# 3. Shift(t, 1)(x + z) -> Shift(t, 1)(x) + Shift(t, 1)(z)
481+
478482
function simplify_shifts(var)
479483
ModelingToolkit.hasshift(var) || return var
480484
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
485+
((op = operation(var)) isa Shift) && op.steps == 0 && return simplify_shifts(arguments(var)[1])
481486
if isdoubleshift(var)
482487
op1 = operation(var)
483488
vv1 = arguments(var)[1]

0 commit comments

Comments
 (0)