Skip to content

Commit 2521506

Browse files
committed
Format and some clean up
1 parent 581247e commit 2521506

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ function check_diff_graph(var_to_diff, fullvars)
217217
end
218218
=#
219219

220-
function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing; simplify = false)
220+
function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
221+
simplify = false)
221222
@unpack fullvars, sys, structure = state
222223
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
223224

@@ -555,9 +556,11 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
555556
if length(diff_vars_set) != length(diff_vars)
556557
error("Tearing internal error: lowering DAE into semi-implicit ODE failed!")
557558
end
559+
solved_variables_set = BitSet(solved_variables)
560+
ag === nothing || union!(solved_variables_set, keys(ag))
558561
invvarsperm = [diff_vars;
559562
setdiff!(setdiff(1:ndsts(graph), diff_vars_set),
560-
union!(BitSet(solved_variables), keys(ag)))]
563+
solved_variables_set)]
561564
varsperm = zeros(Int, ndsts(graph))
562565
for (i, v) in enumerate(invvarsperm)
563566
varsperm[v] = i
@@ -689,7 +692,8 @@ end
689692
Perform index reduction and use the dummy derivative technique to ensure that
690693
the system is balanced.
691694
"""
692-
function dummy_derivative(sys, state = TearingState(sys), ag = nothing; simplify = false, kwargs...)
695+
function dummy_derivative(sys, state = TearingState(sys), ag = nothing; simplify = false,
696+
kwargs...)
693697
jac = let state = state
694698
(eqs, vars) -> begin
695699
symeqs = EquationsView(state)[eqs]
@@ -711,6 +715,7 @@ function dummy_derivative(sys, state = TearingState(sys), ag = nothing; simplify
711715
p
712716
end
713717
end
714-
var_eq_matching = dummy_derivative_graph!(state, jac, (ag, nothing); state_priority, kwargs...)
718+
var_eq_matching = dummy_derivative_graph!(state, jac, (ag, nothing); state_priority,
719+
kwargs...)
715720
tearing_reassemble(state, var_eq_matching, ag; simplify = simplify)
716721
end

src/structural_transformation/utils.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,9 @@ end
4646
function check_consistency(state::TearingState, ag = nothing)
4747
fullvars = state.fullvars
4848
@unpack graph, var_to_diff = state.structure
49-
if ag === nothing
50-
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0,
51-
vertices(var_to_diff))
52-
else
53-
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 && !haskey(ag, v),
54-
vertices(var_to_diff))
55-
end
49+
n_highest_vars = count(v -> length(outneighbors(var_to_diff, v)) == 0 &&
50+
(ag === nothing || !haskey(ag, v)),
51+
vertices(var_to_diff))
5652
neqs = nsrcs(graph)
5753
is_balanced = n_highest_vars == neqs
5854

@@ -74,11 +70,12 @@ function check_consistency(state::TearingState, ag = nothing)
7470
# details, check the equation (15) of the original paper.
7571
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
7672
map(collect, edges(var_to_diff))])
77-
extended_var_eq_matching = maximal_matching(extended_graph, eq->true, v->ag !== nothing || !haskey(ag, v))
73+
extended_var_eq_matching = maximal_matching(extended_graph, eq -> true,
74+
v -> ag === nothing || !haskey(ag, v))
7875

7976
unassigned_var = []
8077
for (vj, eq) in enumerate(extended_var_eq_matching)
81-
if eq === unassigned && !haskey(ag, vj)
78+
if eq === unassigned && (ag === nothing || !haskey(ag, vj))
8279
push!(unassigned_var, fullvars[vj])
8380
end
8481
end

src/systems/alias_elimination.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function alias_elimination!(state::TearingState)
5353
complete!(state.structure)
5454
graph_orig = copy(state.structure.graph)
5555
ag, mm, complete_ag, complete_mm, updated_diff_vars = alias_eliminate_graph!(state)
56-
isempty(ag) && return sys
56+
isempty(ag) && return sys, ag
5757

5858
fullvars = state.fullvars
5959
@unpack var_to_diff, graph, solvable_graph = state.structure
@@ -180,7 +180,6 @@ function alias_elimination!(state::TearingState)
180180
state.structure.graph = new_graph
181181
state.structure.solvable_graph = new_solvable_graph
182182
state.structure.eq_to_diff = new_eq_to_diff
183-
@show length(new_eq_to_diff), nsrcs(new_graph), nsrcs(new_solvable_graph), length(eqs)
184183

185184
#=
186185
new_graph = BipartiteGraph(n_new_eqs, n_new_vars)

0 commit comments

Comments
 (0)