Skip to content

Commit 07644d8

Browse files
committed
Tearing differentiated variables on priority
1 parent 2400c58 commit 07644d8

File tree

2 files changed

+57
-24
lines changed

2 files changed

+57
-24
lines changed

src/structural_transformation/bipartite_tearing/modia_tearing.jl

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,62 @@ function try_assign_eq!(ict::IncrementalCycleTracker, vj::Integer, eq::Integer)
99
end
1010
end
1111

12-
function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int},
13-
vs::Vector{Int})
12+
function try_assign_eq!(ict::IncrementalCycleTracker, vars, v_active, eq::Integer,
13+
condition::F = _ -> true) where {F}
1414
G = ict.graph
15-
vActive = BitSet(vs)
15+
for vj in vars
16+
(vj in v_active && G.matching[vj] === unassigned && condition(vj)) || continue
17+
try_assign_eq!(ict, vj, eq) && return true
18+
end
19+
return false
20+
end
1621

22+
function tearEquations!(ict::IncrementalCycleTracker, Gsolvable, es::Vector{Int},
23+
v_active::BitSet, isder′::F) where {F}
24+
check_der = isder′ !== nothing
25+
if check_der
26+
has_der = Ref(false)
27+
isder = let has_der = has_der, isder′ = isder′
28+
v -> begin
29+
r = isder′(v)
30+
has_der[] |= r
31+
r
32+
end
33+
end
34+
end
1735
for eq in es # iterate only over equations that are not in eSolvedFixed
18-
for vj in Gsolvable[eq]
19-
if G.matching[vj] === unassigned && (vj in vActive)
20-
r = try_assign_eq!(ict, vj, eq)
21-
r && break
36+
vs = Gsolvable[eq]
37+
#=
38+
if check_der
39+
# if there're differentiated variables, then only consider them
40+
try_assign_eq!(ict, vs, v_active, eq, isder)
41+
@show has_der[]
42+
if has_der[]
43+
has_der[] = false
44+
continue
2245
end
2346
end
47+
=#
48+
try_assign_eq!(ict, vs, v_active, eq)
2449
end
2550

2651
return ict
2752
end
2853

29-
function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, vars)
54+
function tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, eqs, vars,
55+
isder::F) where {F}
3056
ict = IncrementalCycleTracker(DiCMOBiGraph{true}(graph); dir = :in)
31-
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars)
57+
tearEquations!(ict, solvable_graph.fadjlist, eqs, vars, isder)
3258
for var in vars
3359
var_eq_matching[var] = ict.graph.matching[var]
3460
end
3561
return nothing
3662
end
3763

38-
function tear_graph_modia(structure::SystemStructure, ::Type{U} = Unassigned;
39-
varfilter = v -> true, eqfilter = eq -> true) where {U}
64+
function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
65+
::Type{U} = Unassigned;
66+
varfilter::F2 = v -> true,
67+
eqfilter::F3 = eq -> true) where {F, U, F2, F3}
4068
# It would be possible here to simply iterate over all variables and attempt to
4169
# use tearEquations! to produce a matching that greedily selects the minimal
4270
# number of torn variables. However, we can do this process faster if we first
@@ -52,14 +80,22 @@ function tear_graph_modia(structure::SystemStructure, ::Type{U} = Unassigned;
5280
var_eq_matching = complete(maximal_matching(graph, eqfilter, varfilter, U))
5381
var_sccs::Vector{Union{Vector{Int}, Int}} = find_var_sccs(graph, var_eq_matching)
5482

83+
ieqs = Int[]
84+
filtered_vars = BitSet()
5585
for vars in var_sccs
56-
filtered_vars = filter(varfilter, vars)
57-
ieqs = Int[var_eq_matching[v]
58-
for v in filtered_vars if var_eq_matching[v] !== unassigned]
5986
for var in vars
87+
if varfilter(var)
88+
push!(filtered_vars, var)
89+
if var_eq_matching[var] !== unassigned
90+
push!(ieqs, var_eq_matching[var])
91+
end
92+
end
6093
var_eq_matching[var] = unassigned
6194
end
62-
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars)
95+
tear_graph_block_modia!(var_eq_matching, graph, solvable_graph, ieqs, filtered_vars,
96+
isder)
97+
empty!(ieqs)
98+
empty!(filtered_vars)
6399
end
64100

65101
return var_eq_matching

src/structural_transformation/partial_state_selection.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
8484
end
8585
filter!(var -> ict.graph.matching[var] === unassigned, to_tear_vars)
8686
filter!(eq -> invview(ict.graph.matching)[eq] === unassigned, to_tear_eqs)
87-
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, to_tear_vars)
87+
tearEquations!(ict, solvable_graph.fadjlist, to_tear_eqs, BitSet(to_tear_vars),
88+
nothing)
8889
for var in to_tear_vars
8990
var_eq_matching[var] = ict.graph.matching[var]
9091
end
@@ -177,7 +178,6 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
177178
invgraph = invview(graph)
178179

179180
eqlevel, _ = compute_diff_level(diff_to_eq)
180-
varlevel, _ = compute_diff_level(diff_to_var)
181181

182182
var_sccs = find_var_sccs(graph, var_eq_matching)
183183
eqcolor = falses(nsrcs(graph))
@@ -191,7 +191,7 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
191191
iszero(maxlevel) && continue
192192

193193
rank_matching = Matching(nvars)
194-
for level in maxlevel:-1:1
194+
for _ in maxlevel:-1:1
195195
eqs = filter(eq -> diff_to_eq[eq] !== nothing, eqs)
196196
nrows = length(eqs)
197197
iszero(nrows) && break
@@ -254,13 +254,10 @@ function dummy_derivative_graph!(structure::SystemStructure, var_eq_matching, ja
254254
isdiffed = let diff_to_var = diff_to_var, dummy_derivatives_set = dummy_derivatives_set
255255
v -> diff_to_var[v] !== nothing && !(v in dummy_derivatives_set)
256256
end
257-
should_consider = let graph = graph, isdiffed = isdiffed
258-
eq -> !any(isdiffed, 𝑠neighbors(graph, eq))
259-
end
260257

261-
var_eq_matching = tear_graph_modia(structure, Union{Unassigned, SelectedState};
262-
varfilter = can_eliminate,
263-
eqfilter = should_consider)
258+
var_eq_matching = tear_graph_modia(structure, isdiffed,
259+
Union{Unassigned, SelectedState};
260+
varfilter = can_eliminate)
264261
for v in eachindex(var_eq_matching)
265262
can_eliminate(v) && continue
266263
var_eq_matching[v] = SelectedState()

0 commit comments

Comments
 (0)