@@ -95,25 +95,47 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
95
95
sparse (I, J, true )
96
96
end
97
97
98
- function gen_nlsolve ( eqs, vars, u0map:: AbstractDict , assignments, deps, var2assignment; checkbounds= true )
98
+ function gen_nlsolve! (is_not_prepended_assignment, eqs, vars, u0map:: AbstractDict , assignments, deps, var2assignment; checkbounds= true )
99
99
isempty (vars) && throw (ArgumentError (" vars may not be empty" ))
100
100
length (eqs) == length (vars) || throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" ))
101
101
rhss = map (x-> x. rhs, eqs)
102
102
# We use `vars` instead of `graph` to capture parameters, too.
103
- paramset = Set {Any} (Iterators . flatten ( ModelingToolkit. vars (r) for r in rhss) )
103
+ paramset = ModelingToolkit. vars (r for r in rhss)
104
104
105
+ # Compute necessary assignments for the nlsolve expr
105
106
init_assignments = [var2assignment[p] for p in paramset if haskey (var2assignment, p)]
106
107
tmp = [init_assignments]
107
108
# `deps[init_assignments]` gives the dependency of `init_assignments`
108
- while (next_assignments = reduce (vcat, deps[init_assignments]); ! isempty (next_assignments))
109
+ successors = Dict {Int,Vector{Int}} ()
110
+ while true
111
+ next_assignments = reduce (vcat, deps[init_assignments])
112
+ isempty (next_assignments) && break
109
113
init_assignments = next_assignments
110
114
push! (tmp, init_assignments)
111
115
end
112
- needed_assignments = mapreduce (i-> assignments[i], vcat, unique (reverse (tmp)))
113
- extravars = Set {Any} (Iterators. flatten (ModelingToolkit. vars (r. rhs) for r in needed_assignments))
116
+ needed_assignments_idxs = reduce (vcat, unique (reverse (tmp)))
117
+ needed_assignments = assignments[needed_assignments_idxs]
118
+
119
+ # Compute `params`. They are like enclosed variables
120
+ rhsvars = [ModelingToolkit. vars (r. rhs) for r in needed_assignments]
121
+ is_vars_independent = isdisjoint .((vars,), rhsvars)
122
+ inner_assignments = []; outer_idxs = Int[]
123
+ outer_assignments = []; inner_idxs = Int[]
124
+ for (i, ind) in enumerate (is_vars_independent)
125
+ a = needed_assignments[i]
126
+ if ind
127
+ push! (outer_assignments, a)
128
+ push! (outer_idxs, i)
129
+ else
130
+ push! (inner_assignments, a)
131
+ push! (inner_idxs, i)
132
+ end
133
+ end
134
+ extravars = reduce (union!, rhsvars[inner_idxs], init= Set ())
114
135
union! (paramset, extravars)
115
- # these are not the subject of the root finding
116
- setdiff! (paramset, vars); setdiff! (paramset, map (a-> a. lhs, needed_assignments))
136
+ setdiff! (paramset, vars)
137
+ setdiff! (paramset, [needed_assignments[i]. lhs for i in inner_idxs])
138
+ union! (paramset, [needed_assignments[i]. lhs for i in outer_idxs])
117
139
params = collect (paramset)
118
140
119
141
# splatting to tighten the type
@@ -141,7 +163,7 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict, assignments, deps, var2assi
141
163
],
142
164
[],
143
165
Let (
144
- needed_assignments,
166
+ needed_assignments[inner_idxs] ,
145
167
isscalar ? rhss[1 ] : MakeArray (rhss, SVector)
146
168
)
147
169
) |> SymbolicUtils. Code. toexpr
@@ -157,7 +179,16 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict, assignments, deps, var2assi
157
179
)
158
180
end )
159
181
182
+ preassignments = []
183
+ for i in outer_idxs
184
+ ii = needed_assignments_idxs[i]
185
+ is_not_prepended_assignment[ii] || continue
186
+ is_not_prepended_assignment[ii] = false
187
+ push! (preassignments, assignments[ii])
188
+ end
189
+
160
190
nlsolve_expr = Assignment[
191
+ preassignments
161
192
fname ← @RuntimeGeneratedFunction (f)
162
193
DestructuredArgs (vars, inbounds= ! checkbounds) ← solver_call
163
194
]
@@ -197,6 +228,7 @@ function build_torn_function(
197
228
198
229
assignments, deps, bf_states = tearing_assignments (sys)
199
230
var2assignment = Dict {Any,Int} (eq. lhs => i for (i, eq) in enumerate (assignments))
231
+ is_not_prepended_assignment = trues (length (assignments))
200
232
201
233
torn_expr = Assignment[]
202
234
@@ -209,18 +241,7 @@ function build_torn_function(
209
241
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
210
242
isempty (torn_eqs_idxs) && continue
211
243
if length (torn_eqs_idxs) <= max_inlining_size
212
- nlsolve_expr = gen_nlsolve (eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds= checkbounds)
213
- #=
214
- # a temporary vector that we need to reverse to get the correct
215
- # dependency evaluation order.
216
- local_deps = Vector{Int}[]
217
- init_deps = [var2assignment[p] for p in params if haskey(var2assignment, p)]
218
- push!(local_deps, init_deps)
219
- while (next_deps = reduce(vcat, deps[init_deps]); !isempty(next_deps))
220
- init_deps = next_deps
221
- push!(local_deps, init_deps)
222
- end
223
- =#
244
+ nlsolve_expr = gen_nlsolve! (is_not_prepended_assignment, eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds= checkbounds)
224
245
append! (torn_expr, nlsolve_expr)
225
246
push! (nlsolve_scc_idxs, i)
226
247
else
@@ -256,7 +277,7 @@ function build_torn_function(
256
277
],
257
278
[],
258
279
pre (Let (
259
- [torn_expr; assignments],
280
+ [torn_expr; assignments[is_not_prepended_assignment] ],
260
281
funbody
261
282
))
262
283
),
0 commit comments