Skip to content

Commit 581247e

Browse files
committed
Contract removed equations away
1 parent c8fb745 commit 581247e

File tree

5 files changed

+18
-15
lines changed

5 files changed

+18
-15
lines changed

src/structural_transformation/partial_state_selection.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
248248
end
249249
end
250250
if diff_va !== nothing
251+
# differentiated alias
251252
n_dummys = length(dummy_derivatives)
252253
needed = count(x -> x isa Int, diff_to_eq) - n_dummys
253254
n = 0

src/structural_transformation/symbolics_tearing.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ function check_diff_graph(var_to_diff, fullvars)
217217
end
218218
=#
219219

220-
function tearing_reassemble(state::TearingState, var_eq_matching; simplify = false)
220+
function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing; simplify = false)
221221
@unpack fullvars, sys, structure = state
222222
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
223223

@@ -557,7 +557,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
557557
end
558558
invvarsperm = [diff_vars;
559559
setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
560-
BitSet(solved_variables))]
560+
union!(BitSet(solved_variables), keys(ag)))]
561561
varsperm = zeros(Int, ndsts(graph))
562562
for (i, v) in enumerate(invvarsperm)
563563
varsperm[v] = i
@@ -632,6 +632,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching; simplify = fal
632632
isdiffeq(eq) || continue
633633
obs_sub[eq.lhs] = eq.rhs
634634
end
635+
# TODO: compute the dependency correctly so that we don't have to do this
635636
obs = substitute.([oldobs; subeqs], (obs_sub,))
636637
@set! sys.observed = obs
637638
@set! state.sys = sys
@@ -688,7 +689,7 @@ end
688689
Perform index reduction and use the dummy derivative technique to ensure that
689690
the system is balanced.
690691
"""
691-
function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwargs...)
692+
function dummy_derivative(sys, state = TearingState(sys), ag = nothing; simplify = false, kwargs...)
692693
jac = let state = state
693694
(eqs, vars) -> begin
694695
symeqs = EquationsView(state)[eqs]
@@ -710,6 +711,6 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false, kwar
710711
p
711712
end
712713
end
713-
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority, kwargs...)
714-
tearing_reassemble(state, var_eq_matching; simplify = simplify)
714+
var_eq_matching = dummy_derivative_graph!(state, jac, (ag, nothing); state_priority, kwargs...)
715+
tearing_reassemble(state, var_eq_matching, ag; simplify = simplify)
715716
end

src/structural_transformation/utils.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@ end
4646
function check_consistency(state::TearingState, ag = nothing)
4747
fullvars = state.fullvars
4848
@unpack graph, var_to_diff = state.structure
49-
#n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 && !isempty(𝑑neighbors(graph, v)),
50-
# vertices(var_to_diff))
51-
#n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 && !haskey(ag, v),
52-
# vertices(var_to_diff))
53-
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0,
54-
vertices(var_to_diff))
49+
if ag === nothing
50+
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0,
51+
vertices(var_to_diff))
52+
else
53+
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 && !haskey(ag, v),
54+
vertices(var_to_diff))
55+
end
5556
neqs = nsrcs(graph)
5657
is_balanced = n_highest_vars == neqs
5758

@@ -73,7 +74,7 @@ function check_consistency(state::TearingState, ag = nothing)
7374
# details, check the equation (15) of the original paper.
7475
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
7576
map(collect, edges(var_to_diff))])
76-
extended_var_eq_matching = maximal_matching(extended_graph, eq->true, v->!haskey(ag, v))
77+
extended_var_eq_matching = maximal_matching(extended_graph, eq->true, v->ag !== nothing || !haskey(ag, v))
7778

7879
unassigned_var = []
7980
for (vj, eq) in enumerate(extended_var_eq_matching)

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ function structural_simplify(sys::AbstractSystem, io = nothing; simplify = false
10411041
#has_io && markio!(state, io..., check = false)
10421042
check_consistency(state, ag)
10431043
#find_solvables!(state; kwargs...)
1044-
sys = dummy_derivative(sys, state; simplify)
1044+
sys = dummy_derivative(sys, state, ag; simplify)
10451045
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
10461046
@set! sys.observed = topsort_equations(observed(sys), fullstates)
10471047
invalidate_cache!(sys)

src/systems/alias_elimination.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ function alias_elimination!(state::TearingState)
167167
diff_to_var[j] === nothing && push!(newstates, fullvars[j])
168168
end
169169
end
170-
#=
171170
new_graph = BipartiteGraph(n_new_eqs, ndsts(graph))
172171
new_solvable_graph = BipartiteGraph(n_new_eqs, ndsts(graph))
173172
new_eq_to_diff = DiffGraph(n_new_eqs)
@@ -182,8 +181,8 @@ function alias_elimination!(state::TearingState)
182181
state.structure.solvable_graph = new_solvable_graph
183182
state.structure.eq_to_diff = new_eq_to_diff
184183
@show length(new_eq_to_diff), nsrcs(new_graph), nsrcs(new_solvable_graph), length(eqs)
185-
=#
186184

185+
#=
187186
new_graph = BipartiteGraph(n_new_eqs, n_new_vars)
188187
new_solvable_graph = BipartiteGraph(n_new_eqs, n_new_vars)
189188
new_eq_to_diff = DiffGraph(n_new_eqs)
@@ -207,6 +206,7 @@ function alias_elimination!(state::TearingState)
207206
state.structure.eq_to_diff = complete(new_eq_to_diff)
208207
state.structure.var_to_diff = complete(new_var_to_diff)
209208
state.fullvars = new_fullvars
209+
=#
210210

211211
sys = state.sys
212212
@set! sys.eqs = eqs

0 commit comments

Comments
 (0)