Skip to content

Commit 456e5e2

Browse files
committed
WIP: Reorder equations and variables
1 parent b44667c commit 456e5e2

File tree

3 files changed

+47
-36
lines changed

3 files changed

+47
-36
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function substitute_vars!(graph::BipartiteGraph, subs, cache = Int[], callback!
134134
graph
135135
end
136136

137-
function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F) where F
137+
function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F, var_to_diff) where F
138138
eq = neweqs[ieq]
139139
if !(eq.lhs isa Number && eq.lhs == 0)
140140
eq = 0 ~ eq.rhs - eq.lhs
@@ -153,16 +153,16 @@ function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F) where F
153153
dervar = var
154154
end
155155
end
156-
dervar === nothing && return 0 ~ rhs
156+
dervar === nothing && return (0 ~ rhs), dervar
157157
new_lhs = var = fullvars[dervar]
158158
# 0 ~ a * D(x) + b
159159
# D(x) ~ -b/a
160160
a, b, islinear = linear_expansion(rhs, var)
161161
if !islinear
162-
return 0 ~ rhs
162+
return (0 ~ rhs), nothing
163163
end
164164
new_rhs = -b / a
165-
return new_lhs ~ new_rhs
165+
return (new_lhs ~ new_rhs), invview(var_to_diff)[dervar]
166166
else # a number
167167
if abs(rhs) > 100eps(float(rhs))
168168
@warn "The equation $eq is not consistent. It simplifed to 0 == $rhs."
@@ -423,24 +423,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
423423
empty!(subs)
424424
end
425425

426-
diffeq_idxs = BitSet()
427-
final_eqs = Equation[]
428-
var_rename = zeros(Int, length(var_eq_matching))
426+
diffeq_idxs = Int[]
427+
algeeq_idxs = Int[]
428+
diff_eqs = Equation[]
429+
alge_eqs = Equation[]
430+
diff_vars = Int[]
429431
subeqs = Equation[]
430432
solved_equations = Int[]
431433
solved_variables = Int[]
432434
idx = 0
433435
# Solve solvable equations
434-
for (iv, ieq) in enumerate(var_eq_matching)
435-
if is_solvable(ieq, iv)
436-
if isdervar(iv)
437-
var_rename[iv] = (idx += 1)
438-
end
439-
var_rename[iv] = -1
440-
else
441-
var_rename[iv] = (idx += 1)
442-
end
443-
end
444436
neqs = nsrcs(graph)
445437
for (ieq, iv) in enumerate(invview(var_eq_matching))
446438
ieq > neqs && break
@@ -450,14 +442,17 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
450442
# We cannot solve the differential variable like D(x)
451443
if isdervar(iv)
452444
# TODO: what if `to_mass_matrix_form(ieq)` returns `nothing`?
453-
push!(final_eqs, to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar))
445+
eq, diffidx = to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar, var_to_diff)
446+
push!(diff_eqs, eq)
454447
push!(diffeq_idxs, ieq)
448+
push!(diff_vars, diffidx)
455449
continue
456450
end
457451
eq = neweqs[ieq]
458452
var = fullvars[iv]
459453
residual = eq.lhs - eq.rhs
460454
a, b, islinear = linear_expansion(residual, var)
455+
@assert islinear
461456
# 0 ~ a * var + b
462457
# var ~ -b/a
463458
if ModelingToolkit._iszero(a)
@@ -471,11 +466,30 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
471466
push!(solved_variables, iv)
472467
end
473468
else
474-
push!(final_eqs, to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar))
469+
eq, diffidx = to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar, var_to_diff)
470+
if diffidx === nothing
471+
push!(alge_eqs, eq)
472+
push!(algeeq_idxs, ieq)
473+
else
474+
push!(diff_eqs, eq)
475+
push!(diffeq_idxs, ieq)
476+
push!(diff_vars, diffidx)
477+
end
475478
end
476479
end
477480
# TODO: BLT sorting
478-
neweqs = final_eqs
481+
neweqs = [diff_eqs; alge_eqs]
482+
eqsperm = [diffeq_idxs; algeeq_idxs]
483+
diff_vars_set = BitSet(diff_vars)
484+
if length(diff_vars_set) != length(diff_vars)
485+
error("Tearing internal error: lowering DAE into semi-implicit ODE failed!")
486+
end
487+
invvarsperm = [diff_vars; setdiff(setdiff(1:ndsts(graph), diff_vars_set), BitSet(solved_variables))]
488+
varsperm = zeros(Int, ndsts(graph))
489+
for (i, v) in enumerate(invvarsperm)
490+
varsperm[v] = i
491+
end
492+
@show varsperm
479493

480494
if isempty(solved_equations)
481495
deps = Vector{Int}[]
@@ -493,31 +507,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
493507

494508
# Contract the vertices in the structure graph to make the structure match
495509
# the new reality of the system we've just created.
496-
#
497-
# TODO: fix ordering and remove equations
498-
graph = contract_variables(graph, var_eq_matching, solved_variables)
510+
graph = contract_variables(graph, var_eq_matching, varsperm, solved_variables, eqsperm)
499511

500512
# Update system
501-
solved_variables_set = BitSet(solved_variables)
502-
active_vars = setdiff!(setdiff!(BitSet(1:length(fullvars)), solved_variables_set),
503-
removed_vars)
504-
new_var_to_diff = complete(DiffGraph(length(active_vars)))
513+
new_var_to_diff = complete(DiffGraph(length(invvarsperm)))
505514
idx = 0
506515
for (v, d) in enumerate(var_to_diff)
507-
v′ = var_rename[v]
516+
v′ = varsperm[v]
508517
(v′ > 0 && d !== nothing) || continue
509-
d′ = var_rename[d]
518+
d′ = varsperm[d]
510519
new_var_to_diff[v′] = d′ > 0 ? d′ : nothing
511520
end
521+
var_to_diff = new_var_to_diff
522+
diff_to_var = invview(var_to_diff)
512523

513524
@set! state.structure.graph = graph
514525
# Note that `eq_to_diff` is not updated
515-
@set! state.structure.var_to_diff = new_var_to_diff
516-
@set! state.fullvars = [v for (i, v) in enumerate(fullvars) if i in active_vars]
526+
@set! state.structure.var_to_diff = var_to_diff
527+
@set! state.fullvars = fullvars = fullvars[invvarsperm]
517528

518529
sys = state.sys
519530
@set! sys.eqs = neweqs
520-
@set! sys.states = [fullvars[i] for i in active_vars if diff_to_var[i] === nothing]
531+
@set! sys.states = [v for (i, v) in enumerate(fullvars) if diff_to_var[i] === nothing]
521532
deleteat!(oldobs, sort!(removed_obs))
522533
@set! sys.observed = [oldobs; subeqs]
523534
@set! sys.substitutions = Substitutions(subeqs, deps)

src/structural_transformation/tearing.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@ function masked_cumsum!(A::Vector)
2222
end
2323

2424
function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching,
25-
eliminated_variables)
26-
var_rename = ones(Int64, ndsts(graph))
25+
var_rename, eliminated_variables, eqsperm = nothing)
2726
eq_rename = ones(Int64, nsrcs(graph))
2827
for v in eliminated_variables
2928
eq_rename[var_eq_matching[v]] = 0
3029
var_rename[v] = 0
3130
end
32-
masked_cumsum!(var_rename)
3331
masked_cumsum!(eq_rename)
3432

3533
dig = DiCMOBiGraph{true}(graph, var_eq_matching)
@@ -45,6 +43,9 @@ function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching,
4543
for e in 𝑠vertices(graph)
4644
ne = eq_rename[e]
4745
ne == 0 && continue
46+
if eqsperm !== nothing
47+
ne = eq_rename[eqsperm[ne]]
48+
end
4849
for v in 𝑠neighbors(graph, e)
4950
newvar = var_rename[v]
5051
if newvar != 0

src/structural_transformation/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ end
113113
function sorted_incidence_matrix(ts::TransformationState, val = true; only_algeqs = false,
114114
only_algvars = false)
115115
var_eq_matching, var_scc = algebraic_variables_scc(ts)
116-
fullvars = ts.fullvars
117116
s = ts.structure
118117
graph = ts.structure.graph
119118
varsmap = zeros(Int, ndsts(graph))

0 commit comments

Comments
 (0)