Skip to content

Commit e3fdefe

Browse files
authored
Merge pull request #1872 from SciML/myb/one_tearing_state
Update structural info after alias elimination
2 parents 17432c8 + 2674a15 commit e3fdefe

File tree

6 files changed

+179
-52
lines changed

6 files changed

+179
-52
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
248248
end
249249
end
250250
if diff_va !== nothing
251+
# differentiated alias
251252
n_dummys = length(dummy_derivatives)
252253
needed = count(x -> x isa Int, diff_to_eq) - n_dummys
253254
n = 0
@@ -265,36 +266,68 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
265266
@warn "The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)."
266267
end
267268
dummy_derivatives_set = BitSet(dummy_derivatives)
269+
270+
if ag !== nothing
271+
function isreducible(x)
272+
# `k` is reducible if all lower differentiated variables are.
273+
isred = true
274+
while isred
275+
if x in dummy_derivatives_set
276+
break
277+
end
278+
x = diff_to_var[x]
279+
x === nothing && break
280+
if !haskey(ag, x)
281+
isred = false
282+
end
283+
end
284+
isred
285+
end
286+
irreducible_set = BitSet()
287+
for (k, (_, v)) in ag
288+
isreducible(k) || push!(irreducible_set, k)
289+
isreducible(k) || push!(irreducible_set, k)
290+
push!(irreducible_set, v)
291+
end
292+
end
293+
294+
is_not_present = v -> isempty(𝑑neighbors(graph, v)) &&
295+
(ag === nothing || (haskey(ag, v) && !(v in irreducible_set)))
296+
# Derivatives that are either in the dummy derivatives set or ended up not
297+
# participating in the system at all are not considered differential
298+
is_some_diff = let dummy_derivatives_set = dummy_derivatives_set
299+
v -> !(v in dummy_derivatives_set) &&
300+
!(var_to_diff[v] === nothing && is_not_present(v))
301+
end
302+
303+
# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with
304+
# actually differentiated variables.
305+
isdiffed = let diff_to_var = diff_to_var
306+
v -> diff_to_var[v] !== nothing && is_some_diff(v)
307+
end
308+
268309
# We can eliminate variables that are not a selected state (differential
269310
# variables). Selected states are differentiated variables that are not
270311
# dummy derivatives.
271-
can_eliminate = let var_to_diff = var_to_diff,
272-
dummy_derivatives_set = dummy_derivatives_set
273-
312+
can_eliminate = let var_to_diff = var_to_diff
274313
v -> begin
275314
if ag !== nothing
276315
haskey(ag, v) && return false
277316
end
278317
dv = var_to_diff[v]
279-
dv === nothing || dv in dummy_derivatives_set
318+
dv === nothing && return true
319+
is_some_diff(dv) || return true
320+
return false
280321
end
281322
end
282323

283-
# We don't want tearing to give us `y_t ~ D(y)`, so we skip equations with
284-
# actually differentiated variables.
285-
isdiffed = let diff_to_var = diff_to_var, dummy_derivatives_set = dummy_derivatives_set
286-
v -> diff_to_var[v] !== nothing && !(v in dummy_derivatives_set)
287-
end
288-
289324
var_eq_matching = tear_graph_modia(structure, isdiffed,
290325
Union{Unassigned, SelectedState};
291326
varfilter = can_eliminate)
292327
for v in eachindex(var_eq_matching)
293-
if ag !== nothing && haskey(ag, v) && iszero(ag[v][1])
294-
continue
295-
end
328+
is_not_present(v) && continue
296329
dv = var_to_diff[v]
297-
(dv === nothing || dv in dummy_derivatives_set) && continue
330+
(dv === nothing || !is_some_diff(dv)) && continue
298331
var_eq_matching[v] = SelectedState()
299332
end
300333

src/structural_transformation/symbolics_tearing.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ function check_diff_graph(var_to_diff, fullvars)
217217
end
218218
=#
219219

220-
function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false)
220+
function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
221+
simplify = false)
221222
@unpack fullvars, sys, structure = state
222223
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
223224

@@ -318,6 +319,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
318319
var -> diff_to_var[var] !== nothing
319320
end
320321

322+
#retear = BitSet()
321323
# There are three cases where we want to generate new variables to convert
322324
# the system into first order (semi-implicit) ODEs.
323325
#
@@ -468,20 +470,38 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
468470
for (ogidx, dx_idx, x_t_idx) in Iterators.reverse(subinfo)
469471
# We need a loop here because both `D(D(x))` and `D(x_t)` need to be
470472
# substituted to `x_tt`.
471-
for idx in (ogidx, dx_idx)
473+
for idx in (ogidx == dx_idx ? ogidx : (ogidx, dx_idx))
472474
subidx = ((idx => x_t_idx),)
473475
# This handles case 2.2
474-
if var_eq_matching[idx] isa Int
475-
var_eq_matching[x_t_idx] = var_eq_matching[idx]
476-
end
477476
substitute_vars!(structure, subidx, idx_buffer, sub_callback!;
478477
exclude = order_lowering_eqs)
478+
if var_eq_matching[idx] isa Int
479+
original_assigned_eq = var_eq_matching[idx]
480+
# This removes the assignment of the variable `idx`, so we
481+
# should consider assign them again later.
482+
var_eq_matching[x_t_idx] = original_assigned_eq
483+
#if !isempty(𝑑neighbors(graph, idx))
484+
# push!(retear, idx)
485+
#end
486+
end
479487
end
480488
end
481489
empty!(subinfo)
482490
empty!(subs)
483491
end
484492

493+
#ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph, var_eq_matching); dir = :in)
494+
#for idx in retear
495+
# for alternative_eq in 𝑑neighbors(solvable_graph, idx)
496+
# # skip actually differentiated variables
497+
# any(𝑠neighbors(graph, alternative_eq)) do alternative_v
498+
# ((vv = diff_to_var[alternative_v]) !== nothing &&
499+
# var_eq_matching[vv] === SelectedState())
500+
# end && continue
501+
# try_assign_eq!(ict, idx, alternative_eq) && break
502+
# end
503+
#end
504+
485505
# Will reorder equations and states to be:
486506
# [diffeqs; ...]
487507
# [diffvars; ...]
@@ -555,9 +575,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
555575
if length(diff_vars_set) != length(diff_vars)
556576
error("Tearing internal error: lowering DAE into semi-implicit ODE failed!")
557577
end
578+
solved_variables_set = BitSet(solved_variables)
579+
ag === nothing || union!(solved_variables_set, keys(ag))
558580
invvarsperm = [diff_vars;
559581
setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
560-
BitSet(solved_variables))]
582+
solved_variables_set)]
561583
varsperm = zeros(Int, ndsts(graph))
562584
for (i, v) in enumerate(invvarsperm)
563585
varsperm[v] = i
@@ -580,7 +602,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
580602
# Contract the vertices in the structure graph to make the structure match
581603
# the new reality of the system we've just created.
582604
graph = contract_variables(graph, var_eq_matching, varsperm, eqsperm,
583-
length(solved_variables))
605+
length(solved_variables), length(solved_variables_set))
584606

585607
# Update system
586608
new_var_to_diff = complete(DiffGraph(length(invvarsperm)))
@@ -632,6 +654,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
632654
isdiffeq(eq) || continue
633655
obs_sub[eq.lhs] = eq.rhs
634656
end
657+
# TODO: compute the dependency correctly so that we don't have to do this
635658
obs = substitute.([oldobs; subeqs], (obs_sub,))
636659
@set! sys.observed = obs
637660
@set! state.sys = sys
@@ -688,7 +711,8 @@ end
688711
Perform index reduction and use the dummy derivative technique to ensure that
689712
the system is balanced.
690713
"""
691-
function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwargs...)
714+
function dummy_derivative(sys, state = TearingState(sys), ag = nothing; simplify = false,
715+
kwargs...)
692716
jac = let state = state
693717
(eqs, vars) -> begin
694718
symeqs = EquationsView(state)[eqs]
@@ -710,6 +734,7 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwar
710734
p
711735
end
712736
end
713-
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...)
714-
tearing_reassemble(state, var_eq_matching; simplify = simplify)
737+
var_eq_matching = dummy_derivative_graph!(state, jac, (ag, nothing); state_priority,
738+
kwargs...)
739+
tearing_reassemble(state, var_eq_matching, ag; simplify = simplify)
715740
end

src/structural_transformation/tearing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function masked_cumsum!(A::Vector)
2222
end
2323

2424
function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching,
25-
var_rename, eq_rename, nelim)
25+
var_rename, eq_rename, nelim_eq, nelim_var)
2626
dig = DiCMOBiGraph{true}(graph, var_eq_matching)
2727

2828
# Update bipartite graph
@@ -31,7 +31,7 @@ function contract_variables(graph::BipartiteGraph, var_eq_matching::Matching,
3131
for v′ in neighborhood(dig, v, Inf; dir = :in) if var_rename[v′] != 0]
3232
end
3333

34-
newgraph = BipartiteGraph(nsrcs(graph) - nelim, ndsts(graph) - nelim)
34+
newgraph = BipartiteGraph(nsrcs(graph) - nelim_eq, ndsts(graph) - nelim_var)
3535
for e in 𝑠vertices(graph)
3636
ne = eq_rename[e]
3737
ne == 0 && continue

src/structural_transformation/utils.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@ end
4343
###
4444
### Structural check
4545
###
46-
function check_consistency(state::TearingState)
46+
function check_consistency(state::TearingState, ag = nothing)
4747
fullvars = state.fullvars
4848
@unpack graph, var_to_diff = state.structure
49-
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0,
49+
n_highest_vars = count(v -> var_to_diff[v] === nothing &&
50+
!isempty(𝑑neighbors(graph, v)) &&
51+
(ag === nothing || !haskey(ag, v) || ag[v] != v),
5052
vertices(var_to_diff))
5153
neqs = nsrcs(graph)
5254
is_balanced = n_highest_vars == neqs
@@ -69,11 +71,12 @@ function check_consistency(state::TearingState)
6971
# details, check the equation (15) of the original paper.
7072
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
7173
map(collect, edges(var_to_diff))])
72-
extended_var_eq_matching = maximal_matching(extended_graph)
74+
extended_var_eq_matching = maximal_matching(extended_graph, eq -> true,
75+
v -> ag === nothing || !haskey(ag, v))
7376

7477
unassigned_var = []
7578
for (vj, eq) in enumerate(extended_var_eq_matching)
76-
if eq === unassigned
79+
if eq === unassigned && (ag === nothing || !haskey(ag, vj))
7780
push!(unassigned_var, fullvars[vj])
7881
end
7982
end
@@ -228,7 +231,7 @@ function find_solvables!(state::TearingState; kwargs...)
228231
return nothing
229232
end
230233

231-
function linear_subsys_adjmat!(state::TransformationState)
234+
function linear_subsys_adjmat!(state::TransformationState; kwargs...)
232235
graph = state.structure.graph
233236
if state.structure.solvable_graph === nothing
234237
state.structure.solvable_graph = BipartiteGraph(nsrcs(graph), ndsts(graph))
@@ -240,7 +243,7 @@ function linear_subsys_adjmat!(state::TransformationState)
240243
coeffs = Int[]
241244
to_rm = Int[]
242245
for i in eachindex(eqs)
243-
all_int_vars, rhs = find_eq_solvables!(state, i, to_rm, coeffs)
246+
all_int_vars, rhs = find_eq_solvables!(state, i, to_rm, coeffs; kwargs...)
244247

245248
# Check if all states in the equation is both linear and homogeneous,
246249
# i.e. it is in the form of

src/systems/abstractsystem.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,13 +1035,14 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
10351035
has_io = io !== nothing
10361036
has_io && markio!(state, io...)
10371037
state, input_idxs = inputs_to_parameters!(state, io)
1038-
sys = alias_elimination!(state)
1038+
sys, ag = alias_elimination!(state; kwargs...)
1039+
#ag = AliasGraph(length(ag))
10391040
# TODO: avoid construct `TearingState` again.
1040-
state = TearingState(sys)
1041-
has_io && markio!(state, io..., check = false)
1042-
check_consistency(state)
1043-
find_solvables!(state; kwargs...)
1044-
sys = dummy_derivative(sys, state; simplify)
1041+
#state = TearingState(sys)
1042+
#has_io && markio!(state, io..., check = false)
1043+
check_consistency(state, ag)
1044+
#find_solvables!(state; kwargs...)
1045+
sys = dummy_derivative(sys, state, ag; simplify)
10451046
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
10461047
@set! sys.observed = topsort_equations(observed(sys), fullstates)
10471048
invalidate_cache!(sys)

0 commit comments

Comments
 (0)