Skip to content

Commit 240ab21

Browse files
committed
fix: properly rename variables inside generate_system_equations
1 parent ed143a5 commit 240ab21

File tree

3 files changed

+119
-85
lines changed

3 files changed

+119
-85
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 97 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ called dummy derivatives.
248248
State selection is done. All non-differentiated variables are algebraic
249249
variables, and all variables that appear differentiated are differential variables.
250250
"""
251-
function substitute_derivatives_algevars!(ts::TearingState, neweqs, dummy_sub, var_eq_matching)
251+
function substitute_derivatives_algevars!(ts::TearingState, neweqs, var_eq_matching, dummy_sub)
252252
@unpack fullvars, sys, structure = ts
253253
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
254254
diff_to_var = invview(var_to_diff)
@@ -353,30 +353,32 @@ Effects on the system structure:
353353
- solvable_graph:
354354
- var_eq_matching: match D(x) to the added identity equation
355355
"""
356-
function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching;
357-
is_discrete = false, mm = nothing)
356+
function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matching; mm = nothing)
358357
@unpack fullvars, sys, structure = ts
359358
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
360359
eq_var_matching = invview(var_eq_matching)
361360
diff_to_var = invview(var_to_diff)
362361
iv = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
363-
lower_name = is_discrete ? lower_varname_withshift : lower_varname_with_unit
364-
362+
is_discrete = is_only_discrete(structure)
363+
lower_varname = is_discrete ? lower_shift_varname : lower_varname_with_unit
365364
linear_eqs = mm === nothing ? Dict{Int, Int}() :
366365
Dict(reverse(en) for en in enumerate(mm.nzrows))
367366

368-
# Generate new derivative variables for all unsolved variables that have a derivative in the system
367+
# For variable x, make dummy derivative x_t if the
368+
# derivative is in the system
369369
for v in 1:length(var_to_diff)
370-
# Check if a derivative 1) exists and 2) is unsolved for
371370
dv = var_to_diff[v]
371+
# For discrete systems, directly substitute lowest-order variable
372+
if is_discrete && diff_to_var[v] == nothing
373+
fullvars[v] = lower_varname(fullvars[v], iv)
374+
end
372375
dv isa Int || continue
373376
solved = var_eq_matching[dv] isa Int
374377
solved && continue
375378

376379
# If there's `D(x) = x_t` already, update mappings and continue without
377380
# adding new equations/variables
378-
dd = find_duplicate_dd(dv, lineareqs, mm)
379-
381+
dd = find_duplicate_dd(dv, solvable_graph, linear_eqs, mm)
380382
if !isnothing(dd)
381383
dummy_eq, v_t = dd
382384
var_to_diff[v_t] = var_to_diff[dv]
@@ -386,26 +388,25 @@ function generate_derivative_variables!(ts::TearingState, neweqs, var_eq_matchin
386388
end
387389

388390
dx = fullvars[dv]
389-
order, lv = var_order(diff_to_var, dv)
390-
x_t = is_discrete ? lower_name(fullvars[lv], iv)
391-
: lower_name(fullvars[lv], iv, order)
392-
391+
order, lv = var_order(dv, diff_to_var)
392+
x_t = is_discrete ? lower_varname(fullvars[dv], iv) : lower_varname(fullvars[lv], iv, order)
393+
393394
# Add `x_t` to the graph
394-
add_dd_variable!(structure, x_t, dv)
395+
v_t = add_dd_variable!(structure, fullvars, x_t, dv)
395396
# Add `D(x) - x_t ~ 0` to the graph
396-
add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv)
397+
dummy_eq = add_dd_equation!(structure, neweqs, 0 ~ dx - x_t, dv, v_t)
397398

398399
# Update matching
399400
push!(var_eq_matching, unassigned)
400401
var_eq_matching[dv] = unassigned
401-
eq_var_matching[dummy_eq] = dv
402+
eq_var_matching[dummy_eq] = dv
402403
end
403404
end
404405

405406
"""
406-
Check if there's `D(x) = x_t` already.
407+
Check if there's `D(x) = x_t` already.
407408
"""
408-
function find_duplicate_dd(dv, lineareqs, mm)
409+
function find_duplicate_dd(dv, solvable_graph, linear_eqs, mm)
409410
for eq in 𝑑neighbors(solvable_graph, dv)
410411
mi = get(linear_eqs, eq, 0)
411412
iszero(mi) && continue
@@ -424,28 +425,28 @@ function find_duplicate_dd(dv, lineareqs, mm)
424425
return nothing
425426
end
426427

427-
function add_dd_variable!(s::SystemStructure, x_t, dv)
428-
push!(s.fullvars, simplify_shifts(x_t))
428+
function add_dd_variable!(s::SystemStructure, fullvars, x_t, dv)
429+
push!(fullvars, simplify_shifts(x_t))
430+
v_t = length(fullvars)
429431
v_t_idx = add_vertex!(s.var_to_diff)
430-
@assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
431-
length(var_eq_matching)
432432
add_vertex!(s.graph, DST)
433433
# TODO: do we care about solvable_graph? We don't use them after
434434
# `dummy_derivative_graph`.
435435
add_vertex!(s.solvable_graph, DST)
436-
var_to_diff[v_t] = var_to_diff[dv]
436+
s.var_to_diff[v_t] = s.var_to_diff[dv]
437+
v_t
437438
end
438439

439440
# dv = index of D(x), v_t = index of x_t
440-
function add_dd_equation!(s::SystemStructure, neweqs, eq, dv)
441+
function add_dd_equation!(s::SystemStructure, neweqs, eq, dv, v_t)
441442
push!(neweqs, eq)
442443
add_vertex!(s.graph, SRC)
443-
v_t = length(s.fullvars)
444444
dummy_eq = length(neweqs)
445445
add_edge!(s.graph, dummy_eq, dv)
446446
add_edge!(s.graph, dummy_eq, v_t)
447447
add_vertex!(s.solvable_graph, SRC)
448448
add_edge!(s.solvable_graph, dummy_eq, dv)
449+
dummy_eq
449450
end
450451

451452
"""
@@ -463,6 +464,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
463464
iv = get_iv(sys)
464465
if is_only_discrete(structure)
465466
D = Shift(iv, 1)
467+
for v in fullvars
468+
op = operation(v)
469+
op isa Shift && (op.steps < 0) && (total_sub[v] = lower_shift_varname(v, iv))
470+
end
466471
else
467472
D = Differential(iv)
468473
end
@@ -493,24 +498,40 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
493498
eqs = Iterators.reverse(toporder)
494499
idep = iv
495500

496-
# Generate differential equations.
497-
# fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
498-
# is solved to give the RHS.
501+
# Generate equations.
502+
# Solvable equations of differential variables D(x) become differential equations
503+
# Solvable equations of non-differential variables become observable equations
504+
# Non-solvable equations become algebraic equations.
499505
for ieq in eqs
500506
iv = eq_var_matching[ieq]
501-
if is_solvable(ieq, iv)
502-
if isdervar(iv)
503-
isnothing(D) &&
504-
error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
505-
add_differential_equation!(structure, iv, neweqs, ieq,
506-
diff_vars, diff_eqs, diffeq_idxs, total_sub)
507-
else
508-
add_solved_equation!(structure, iv, neweqs, ieq,
509-
solved_vars, solved_eqs, solvedeq_idxs, total_sub)
507+
var = fullvars[iv]
508+
eq = neweqs[ieq]
509+
510+
if is_solvable(ieq, iv) && isdervar(iv)
511+
isnothing(D) && throw(UnexpectedDifferentialError(equations(sys)[ieq]))
512+
order, lv = var_order(iv, diff_to_var)
513+
dx = D(simplify_shifts(fullvars[lv]))
514+
515+
neweq = make_differential_equation(var, dx, eq, total_sub)
516+
for e in 𝑑neighbors(graph, iv)
517+
e == ieq && continue
518+
rem_edge!(graph, e, iv)
519+
end
520+
521+
push!(diff_eqs, neweq)
522+
push!(diffeq_idxs, ieq)
523+
push!(diff_vars, diff_to_var[iv])
524+
elseif is_solvable(ieq, iv)
525+
neweq = make_solved_equation(var, eq, total_sub; simplify)
526+
!isnothing(neweq) && begin
527+
push!(solved_eqs, neweq)
528+
push!(solvedeq_idxs, ieq)
529+
push!(solved_vars, iv)
510530
end
511531
else
512-
add_algebraic_equation!(structure, neweqs, ieq,
513-
alge_eqs, algeeq_idxs, total_sub)
532+
neweq = make_algebraic_equation(var, eq, total_sub)
533+
push!(alge_eqs, neweq)
534+
push!(algeeq_idxs, ieq)
514535
end
515536
end
516537

@@ -529,55 +550,43 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching
529550
return neweqs, solved_eqs, eq_ordering, var_ordering, length(solved_vars), length(solved_vars_set)
530551
end
531552

532-
function add_differential_equation!(s::SystemStructure, iv, neweqs, ieq, diff_vars, diff_eqs, diffeqs_idxs, total_sub)
533-
diff_to_var = invview(s.var_to_diff)
553+
struct UnexpectedDifferentialError
554+
eq::Equation
555+
end
534556

535-
order, lv = var_order(diff_to_var, iv)
536-
dx = D(simplify_shifts(fullvars[lv]))
537-
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
538-
Symbolics.symbolic_linear_solve(neweqs[ieq],
539-
fullvars[iv]),
540-
total_sub; operator = ModelingToolkit.Shift))
541-
for e in 𝑑neighbors(s.graph, iv)
542-
e == ieq && continue
543-
rem_edge!(s.graph, e, iv)
544-
end
557+
function Base.showerror(io::IO, err::UnexpectedDifferentialError)
558+
error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(err.eq)")
559+
end
545560

546-
push!(diff_eqs, eq)
547-
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
548-
push!(diffeq_idxs, ieq)
549-
push!(diff_vars, diff_to_var[iv])
561+
function make_differential_equation(var, dx, eq, total_sub)
562+
dx ~ simplify_shifts(Symbolics.fixpoint_sub(
563+
Symbolics.symbolic_linear_solve(eq, var),
564+
total_sub; operator = ModelingToolkit.Shift))
550565
end
551566

552-
function add_algebraic_equation!(s::SystemStructure, neweqs, ieq, alge_eqs, algeeq_idxs, total_sub)
553-
eq = neweqs[ieq]
567+
function make_algebraic_equation(var, eq, total_sub)
554568
rhs = eq.rhs
555569
if !(eq.lhs isa Number && eq.lhs == 0)
556570
rhs = eq.rhs - eq.lhs
557571
end
558-
push!(alge_eqs, 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)))
559-
push!(algeeq_idxs, ieq)
572+
0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub))
560573
end
561574

562-
function add_solved_equation!(s::SystemStructure, iv, neweqs, ieq, solved_vars, solved_eqs, solvedeq_idxs, total_sub)
563-
eq = neweqs[ieq]
564-
var = fullvars[iv]
575+
function make_solved_equation(var, eq, total_sub; simplify = false)
565576
residual = eq.lhs - eq.rhs
566577
a, b, islinear = linear_expansion(residual, var)
567578
@assert islinear
568579
# 0 ~ a * var + b
569580
# var ~ -b/a
570581
if ModelingToolkit._iszero(a)
571582
@warn "Tearing: solving $eq for $var is singular!"
583+
return nothing
572584
else
573585
rhs = -b / a
574-
neweq = var ~ simplify_shifts(Symbolics.fixpoint_sub(
586+
return var ~ simplify_shifts(Symbolics.fixpoint_sub(
575587
simplify ?
576588
Symbolics.simplify(rhs) : rhs,
577589
total_sub; operator = ModelingToolkit.Shift))
578-
push!(solved_eqs, neweq)
579-
push!(solvedeq_idxs, ieq)
580-
push!(solved_vars, iv)
581590
end
582591
end
583592

@@ -592,8 +601,8 @@ such that the mass matrix is:
592601
Update the state to account for the new ordering and equations.
593602
"""
594603
# TODO: BLT sorting
595-
function reorder_vars!(s::SystemStructure, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
596-
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
604+
function reorder_vars!(state::TearingState, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
605+
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
597606

598607
eqsperm = zeros(Int, nsrcs(graph))
599608
for (i, v) in enumerate(eq_ordering)
@@ -623,20 +632,26 @@ function reorder_vars!(s::SystemStructure, var_eq_matching, eq_ordering, var_ord
623632
d′ = eqsperm[d]
624633
new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing
625634
end
626-
new_fullvars = fullvars[invvarsperm]
635+
new_fullvars = state.fullvars[var_ordering]
627636

637+
@show new_graph
638+
@show new_var_to_diff
628639
# Update system structure
629640
@set! state.structure.graph = complete(new_graph)
630641
@set! state.structure.var_to_diff = new_var_to_diff
631642
@set! state.structure.eq_to_diff = new_eq_to_diff
632643
@set! state.fullvars = new_fullvars
644+
state
633645
end
634646

635647
"""
636648
Set the system equations, unknowns, observables post-tearing.
637649
"""
638-
function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns)
650+
function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
651+
cse_hack = true, array_hack = true)
639652
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
653+
@show graph
654+
@show var_to_diff
640655
diff_to_var = invview(var_to_diff)
641656

642657
ispresent = let var_to_diff = var_to_diff, graph = graph
@@ -656,7 +671,12 @@ function update_simplified_system!(state::TearingState, neweqs, solved_eqs, dumm
656671
unknowns = Any[v
657672
for (i, v) in enumerate(state.fullvars)
658673
if diff_to_var[i] === nothing && ispresent(i)]
674+
@show unknowns
675+
@show state.fullvars
676+
@show 𝑑neighbors(graph, 5)
677+
@show neweqs
659678
unknowns = [unknowns; extra_unknowns]
679+
@show unknowns
660680
@set! sys.unknowns = unknowns
661681

662682
obs, subeqs, deps = cse_and_array_hacks(
@@ -684,7 +704,7 @@ end
684704
# differential variables.
685705

686706
# Give the order of the variable indexed by dv
687-
function var_order(diff_to_var, dv)
707+
function var_order(dv, diff_to_var)
688708
order = 0
689709
while (dv′ = diff_to_var[dv]) !== nothing
690710
order += 1
@@ -695,7 +715,6 @@ end
695715

696716
function tearing_reassemble(state::TearingState, var_eq_matching,
697717
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
698-
699718
extra_vars = Int[]
700719
if full_var_eq_matching !== nothing
701720
for v in 𝑑vertices(state.structure.graph)
@@ -704,21 +723,22 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
704723
push!(extra_vars, v)
705724
end
706725
end
707-
extra_unknowns = fullvars[extra_vars]
726+
extra_unknowns = state.fullvars[extra_vars]
708727
neweqs = collect(equations(state))
728+
dummy_sub = Dict()
709729

710730
# Structural simplification
711-
dummy_sub = Dict()
712-
substitute_derivatives_algevars!(state, neweqs, dummy_sub, var_eq_matching)
731+
substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub)
713732

714733
generate_derivative_variables!(state, neweqs, var_eq_matching; mm)
715734

716735
neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var =
717736
generate_system_equations!(state, neweqs, var_eq_matching; simplify)
718737

719-
reorder_vars!(state.structure, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
738+
state = reorder_vars!(state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
739+
740+
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns; cse_hack, array_hack)
720741

721-
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns)
722742
@set! state.sys = sys
723743
@set! sys.tearing_state = state
724744
return invalidate_cache!(sys)

src/structural_transformation/utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,10 @@ end
452452
### Misc
453453
###
454454

455-
function lower_varname_withshift(var, iv)
455+
# For discrete variables. Turn Shift(t, k)(x(t)) into xₜ₋ₖ(t)
456+
function lower_shift_varname(var, iv)
456457
op = operation(var)
457-
op isa Shift || return var
458+
op isa Shift || return Shift(iv, 0)(var, true) # hack to prevent simplification of x(t) - x(t)
458459
backshift = op.steps
459460
backshift > 0 && return var
460461

@@ -476,6 +477,14 @@ function lower_varname_withshift(var, iv)
476477
return ModelingToolkit._with_unit(identity, newvar, iv)
477478
end
478479

480+
function lower_varname(var, iv, order; is_discrete = false)
481+
if is_discrete
482+
lower_shift_varname(var, iv)
483+
else
484+
lower_varname_with_unit(var, iv, order)
485+
end
486+
end
487+
479488
function isdoubleshift(var)
480489
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
481490
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
@@ -484,6 +493,7 @@ end
484493
function simplify_shifts(var)
485494
ModelingToolkit.hasshift(var) || return var
486495
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
496+
(op = operation(var)) isa Shift && op.steps == 0 && return first(arguments(var))
487497
if isdoubleshift(var)
488498
op1 = operation(var)
489499
vv1 = arguments(var)[1]

0 commit comments

Comments
 (0)