Skip to content

Commit 91acf91

Browse files
committed
beginning implicit equation
1 parent 5a77f43 commit 91acf91

File tree

2 files changed

+37
-40
lines changed

2 files changed

+37
-40
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,8 @@ by iteratively substituting x_t ~ D(x), then x_tt ~ D(x_t). For this reason the
360360
total_sub dict is updated at the time that the renamed variables are written,
361361
inside the loop where new variables are generated.
362362
"""
363-
function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching, var_order;
363+
function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var_eq_matching;
364364
is_discrete = false, mm = nothing)
365-
366365
@unpack fullvars, sys, structure = ts
367366
@unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
368367
eq_var_matching = invview(var_eq_matching)
@@ -386,13 +385,14 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
386385

387386
# v is the index of the current variable, x = fullvars[v]
388387
# dv is the index of the derivative dx = D(x), x_t is the substituted variable
389-
#
390388
# For ODESystems: lv is the index of the lowest-order variable (x(t))
391389
# For DiscreteSystems:
392390
# - lv is the index of the lowest-order variable (Shift(t, k)(x(t)))
393391
# - uv is the index of the highest-order variable (x(t))
394392
for v in 1:length(var_to_diff)
395393
dv = var_to_diff[v]
394+
println()
395+
@show (v, dv)
396396
if is_discrete
397397
x = fullvars[v]
398398
op = operation(x)
@@ -405,6 +405,9 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
405405
end
406406
dv isa Int || continue
407407

408+
@show dv
409+
@show var_eq_matching[dv]
410+
@show fullvars
408411
solved = var_eq_matching[dv] isa Int
409412
solved && continue
410413

@@ -429,9 +432,10 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
429432

430433
dx = fullvars[dv]
431434
# add `x_t`
432-
order, lv = var_order(dv)
435+
order, lv = var_order(diff_to_var, dv)
433436
x_t = is_discrete ? lower_name(fullvars[lv], iv, -low-order-1; unshifted = fullvars[uv]) :
434437
lower_name(fullvars[lv], iv, order)
438+
@show dx, x_t
435439
push!(fullvars, simplify_shifts(x_t))
436440
v_t = length(fullvars)
437441
v_t_idx = add_vertex!(var_to_diff)
@@ -443,23 +447,23 @@ function generate_derivative_variables!(ts::TearingState, neweqs, total_sub, var
443447
@assert v_t_idx == ndsts(graph) == ndsts(solvable_graph) == length(fullvars) ==
444448
length(var_eq_matching)
445449

446-
# Add the substitutions to total_sub directly.
450+
# Add discrete substitutions to total_sub directly.
447451
is_discrete && begin
448452
idx_to_lowest_shift[v_t] = idx_to_lowest_shift[dv]
449-
@show dx
450453
if operation(dx) isa Shift
451454
total_sub[dx] = x_t
452455
for e in 𝑑neighbors(graph, dv)
453456
add_edge!(graph, e, v_t)
454457
rem_edge!(graph, e, dv)
455458
end
456-
@show graph
459+
# Do not add the lowest-order substitution as an equation, just substitute
457460
!(operation(x) isa Shift) && begin
458461
var_to_diff[v_t] = var_to_diff[dv]
459462
continue
460463
end
461464
end
462465
end
466+
463467
# add `D(x) - x_t ~ 0`
464468
push!(neweqs, 0 ~ x_t - dx)
465469
add_vertex!(graph, SRC)
@@ -489,7 +493,7 @@ such that the mass matrix is:
489493
490494
Update the state to account for the new ordering and equations.
491495
"""
492-
function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching, var_order; simplify = false)
496+
function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, var_eq_matching; simplify = false)
493497
@unpack fullvars, sys, structure = state
494498
@unpack solvable_graph, var_to_diff, eq_to_diff, graph, lowest_shift = structure
495499
eq_var_matching = invview(var_eq_matching)
@@ -530,13 +534,11 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
530534

531535
toporder = topological_sort(DiCMOBiGraph{false}(graph, var_eq_matching))
532536
eqs = Iterators.reverse(toporder)
533-
idep = ModelingToolkit.has_iv(sys) ? ModelingToolkit.get_iv(sys) : nothing
534-
535-
@show eq_var_matching
536-
@show fullvars
537-
@show neweqs
537+
idep = iv
538538

539-
# Equation ieq is solved for the RHS of iv
539+
# Generate differential equations.
540+
# fullvars[iv] is a differential variable of the form D^n(x), and neweqs[ieq]
541+
# is solved to give the RHS.
540542
for ieq in eqs
541543
iv = eq_var_matching[ieq]
542544
if is_solvable(ieq, iv)
@@ -546,7 +548,7 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
546548
if isdervar(iv)
547549
isnothing(D) &&
548550
error("Differential found in a non-differential system. Likely this is a bug in the construction of an initialization system. Please report this issue with a reproducible example. Offending equation: $(equations(sys)[ieq])")
549-
order, lv = var_order(iv)
551+
order, lv = var_order(diff_to_var, iv)
550552
dx = D(fullvars[lv])
551553
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
552554
Symbolics.symbolic_linear_solve(neweqs[ieq],
@@ -634,8 +636,9 @@ function solve_and_generate_equations!(state::TearingState, neweqs, total_sub, v
634636
d′ = eqsperm[d]
635637
new_eq_to_diff[v′] = d′ > 0 ? d′ : nothing
636638
end
639+
new_fullvars = fullvars[invvarsperm]
637640

638-
fullvars[invvarsperm], new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph
641+
new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph
639642
end
640643

641644
# Terminology and Definition:
@@ -649,6 +652,16 @@ end
649652

650653
import ModelingToolkit: Shift
651654

655+
# Give the order of the variable indexed by dv
656+
function var_order(diff_to_var, dv)
657+
order = 0
658+
while (dv′ = diff_to_var[dv]) !== nothing
659+
order += 1
660+
dv = dv′
661+
end
662+
order, dv
663+
end
664+
652665
function tearing_reassemble(state::TearingState, var_eq_matching,
653666
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
654667
@unpack fullvars, sys, structure = state
@@ -667,22 +680,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
667680
dummy_sub = Dict()
668681
is_discrete = is_only_discrete(state.structure)
669682

670-
var_order = let diff_to_var = diff_to_var
671-
dv -> begin
672-
order = 0
673-
while (dv′ = diff_to_var[dv]) !== nothing
674-
order += 1
675-
dv = dv′
676-
end
677-
order, dv
678-
end
679-
end
680-
681683
# Structural simplification
682684
substitute_dummy_derivatives!(state, neweqs, dummy_sub, var_eq_matching)
683-
generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching, var_order;
685+
generate_derivative_variables!(state, neweqs, total_sub, var_eq_matching;
684686
is_discrete, mm)
685-
new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph = solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching, var_order; simplify)
687+
new_fullvars, new_var_to_diff, new_eq_to_diff, neweqs, subeqs, graph =
688+
solve_and_generate_equations!(state, neweqs, total_sub, var_eq_matching; simplify)
686689

687690
# Update system
688691
var_to_diff = new_var_to_diff
@@ -698,25 +701,18 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
698701
i -> (!isempty(𝑑neighbors(graph, i)) ||
699702
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
700703
end
701-
@show graph
702-
println()
703-
println("Shift test...")
704-
@show neweqs
705-
@show fullvars
704+
@show ispresent.(collect(1:length(fullvars)))
706705
@show 𝑑neighbors(graph, 5)
706+
@show var_to_diff[5]
707707

708+
@show neweqs
709+
@show fullvars
708710
sys = state.sys
709711
obs_sub = dummy_sub
710712
for eq in neweqs
711713
isdiffeq(eq) || continue
712714
obs_sub[eq.lhs] = eq.rhs
713715
end
714-
is_discrete && for eq in subeqs
715-
obs_sub[eq.rhs] = eq.lhs
716-
end
717-
718-
@show obs_sub
719-
@show observed(sys)
720716
# TODO: compute the dependency correctly so that we don't have to do this
721717
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
722718

src/structural_transformation/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ end
450450
###
451451

452452
function lower_varname_withshift(var, iv, backshift; unshifted = nothing)
453+
backshift < 0 && return Shift(iv, -backshift)(var)
453454
backshift == 0 && return unshifted
454455
ds = "$iv-$backshift"
455456
d_separator = 'ˍ'

0 commit comments

Comments
 (0)