@@ -95,7 +95,7 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
95
95
sparse (I, J, true )
96
96
end
97
97
98
- function gen_nlsolve! (is_not_prepended_assignment, eqs, vars, u0map:: AbstractDict , assignments, deps, var2assignment; checkbounds= true )
98
+ function gen_nlsolve! (is_not_prepended_assignment, eqs, vars, u0map:: AbstractDict , assignments, ( deps, invdeps) , var2assignment; checkbounds= true )
99
99
isempty (vars) && throw (ArgumentError (" vars may not be empty" ))
100
100
length (eqs) == length (vars) || throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" ))
101
101
rhss = map (x-> x. rhs, eqs)
@@ -106,7 +106,6 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
106
106
init_assignments = [var2assignment[p] for p in paramset if haskey (var2assignment, p)]
107
107
tmp = [init_assignments]
108
108
# `deps[init_assignments]` gives the dependency of `init_assignments`
109
- successors = Dict {Int,Vector{Int}} ()
110
109
while true
111
110
next_assignments = reduce (vcat, deps[init_assignments])
112
111
isempty (next_assignments) && break
@@ -118,19 +117,43 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
118
117
119
118
# Compute `params`. They are like enclosed variables
120
119
rhsvars = [ModelingToolkit. vars (r. rhs) for r in needed_assignments]
121
- is_vars_independent = isdisjoint .((vars,), rhsvars)
122
- inner_assignments = []; outer_idxs = Int[]
123
- outer_assignments = []; inner_idxs = Int[]
124
- for (i, ind) in enumerate (is_vars_independent)
125
- a = needed_assignments[i]
126
- if ind
127
- push! (outer_assignments, a)
128
- push! (outer_idxs, i)
120
+ vars_set = Set (vars)
121
+ outer_set = BitSet ()
122
+ inner_set = BitSet ()
123
+ for (i, vs) in enumerate (rhsvars)
124
+ j = needed_assignments_idxs[i]
125
+ if isdisjoint (vars_set, vs)
126
+ push! (outer_set, j)
129
127
else
130
- push! (inner_assignments, a)
131
- push! (inner_idxs, i)
128
+ push! (inner_set, j)
132
129
end
133
130
end
131
+ init_refine = BitSet ()
132
+ for i in inner_set
133
+ union! (init_refine, invdeps[i])
134
+ end
135
+ intersect! (init_refine, outer_set)
136
+ setdiff! (outer_set, init_refine)
137
+ union! (inner_set, init_refine)
138
+
139
+ next_refine = BitSet ()
140
+ while true
141
+ for i in init_refine
142
+ id = invdeps[i]
143
+ isempty (id) && break
144
+ union! (next_refine, id)
145
+ end
146
+ intersect! (next_refine, outer_set)
147
+ isempty (next_refine) && break
148
+ setdiff! (outer_set, next_refine)
149
+ union! (inner_set, next_refine)
150
+
151
+ init_refine, next_refine = next_refine, init_refine
152
+ empty! (next_refine)
153
+ end
154
+ global2local = Dict (j=> i for (i, j) in enumerate (needed_assignments_idxs))
155
+ inner_idxs = [global2local[i] for i in collect (inner_set)]
156
+ outer_idxs = [global2local[i] for i in collect (outer_set)]
134
157
extravars = reduce (union!, rhsvars[inner_idxs], init= Set ())
135
158
union! (paramset, extravars)
136
159
setdiff! (paramset, vars)
@@ -227,6 +250,12 @@ function build_torn_function(
227
250
mass_matrix_diag = ones (length (states_idxs))
228
251
229
252
assignments, deps, bf_states = tearing_assignments (sys)
253
+ invdeps = map (_-> BitSet (), deps)
254
+ for (i, d) in enumerate (deps)
255
+ for a in d
256
+ push! (invdeps[a], i)
257
+ end
258
+ end
230
259
var2assignment = Dict {Any,Int} (eq. lhs => i for (i, eq) in enumerate (assignments))
231
260
is_not_prepended_assignment = trues (length (assignments))
232
261
@@ -241,7 +270,7 @@ function build_torn_function(
241
270
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
242
271
isempty (torn_eqs_idxs) && continue
243
272
if length (torn_eqs_idxs) <= max_inlining_size
244
- nlsolve_expr = gen_nlsolve! (is_not_prepended_assignment, eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds= checkbounds)
273
+ nlsolve_expr = gen_nlsolve! (is_not_prepended_assignment, eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, assignments, ( deps, invdeps) , var2assignment, checkbounds= checkbounds)
245
274
append! (torn_expr, nlsolve_expr)
246
275
push! (nlsolve_scc_idxs, i)
247
276
else
0 commit comments