@@ -95,41 +95,26 @@ function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_s
95
95
sparse (I, J, true )
96
96
end
97
97
98
- """
99
- exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true, assignments)
100
-
101
- Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
102
- as well as a call to the root-finding solver.
103
-
104
- `exprs` is a two element vector
105
- ```
106
- exprs = [fname = f, numerical_nlsolve(fname, ...)]
107
- ```
108
-
109
- # Arguments:
110
- - `eqs`: Equations to find roots of.
111
- - `vars`: ???
112
- - `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
113
- - `checkbounds`: Apply bounds checking in the generated code.
114
- """
115
- function gen_nlsolve (eqs, vars, u0map:: AbstractDict ; checkbounds= true , assignments)
98
+ function gen_nlsolve (eqs, vars, u0map:: AbstractDict , assignments, deps, var2assignment; checkbounds= true )
116
99
isempty (vars) && throw (ArgumentError (" vars may not be empty" ))
117
100
length (eqs) == length (vars) || throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" ))
118
101
rhss = map (x-> x. rhs, eqs)
119
102
# We use `vars` instead of `graph` to capture parameters, too.
120
- allvars = Set (Iterators. flatten (ModelingToolkit. vars (r) for r in rhss))
121
- vars_set = Set (vars)
122
- params = setdiff (allvars, vars_set) # these are not the subject of the root finding
123
- # needed_assignments = filter(a->a.lhs in params, assignments)
124
- needed_assignments = assignments[1 : findlast (a-> a. lhs in params, assignments)]
125
- params = setdiff (params, [a. lhs for a in needed_assignments])
126
- @show needed_assignments, params
127
- for a in needed_assignments
128
- ModelingToolkit. vars! (params, a. rhs)
103
+ paramset = Set (Iterators. flatten (ModelingToolkit. vars (r) for r in rhss))
104
+
105
+ init_assignments = [var2assignment[p] for p in paramset if haskey (var2assignment, p)]
106
+ tmp = [init_assignments]
107
+ # `deps[init_assignments]` gives the dependency of `init_assignments`
108
+ while (next_assignments = reduce (vcat, deps[init_assignments]); ! isempty (next_assignments))
109
+ init_assignments = next_assignments
110
+ push! (tmp, init_assignments)
129
111
end
130
- params = setdiff (params, vars_set) # these are not the subject of the root finding
131
- @show params
132
- # inductor1₊v, inductor2₊v
112
+ needed_assignments = mapreduce (i-> assignments[i], vcat, reverse (tmp))
113
+ extravars = Set (Iterators. flatten (ModelingToolkit. vars (r. rhs) for r in needed_assignments))
114
+ union! (paramset, extravars)
115
+ # these are not the subject of the root finding
116
+ setdiff! (paramset, vars); setdiff! (paramset, map (a-> a. lhs, needed_assignments))
117
+ params = collect (paramset)
133
118
134
119
# splatting to tighten the type
135
120
u0 = []
@@ -152,10 +137,13 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true, assignmen
152
137
f = Func (
153
138
[
154
139
DestructuredArgs (vars, inbounds= ! checkbounds)
155
- DestructuredArgs (collect ( params) , inbounds= ! checkbounds)
140
+ DestructuredArgs (params, inbounds= ! checkbounds)
156
141
],
157
142
[],
158
- Let (needed_assignments, isscalar ? rhss[1 ] : MakeArray (rhss, SVector))
143
+ Let (
144
+ needed_assignments,
145
+ isscalar ? rhss[1 ] : MakeArray (rhss, SVector)
146
+ )
159
147
) |> SymbolicUtils. Code. toexpr
160
148
161
149
# solver call contains code to call the root-finding solver on the function f
@@ -169,10 +157,12 @@ function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true, assignmen
169
157
)
170
158
end )
171
159
172
- [
173
- fname ← @RuntimeGeneratedFunction (f)
174
- DestructuredArgs (vars, inbounds= ! checkbounds) ← solver_call
175
- ]
160
+ nlsolve_expr = Assignment[
161
+ fname ← @RuntimeGeneratedFunction (f)
162
+ DestructuredArgs (vars, inbounds= ! checkbounds) ← solver_call
163
+ ]
164
+
165
+ nlsolve_expr
176
166
end
177
167
178
168
function build_torn_function (
@@ -205,20 +195,33 @@ function build_torn_function(
205
195
states_idxs = collect (diffvars_range (s))
206
196
mass_matrix_diag = ones (length (states_idxs))
207
197
208
- assignments, bf_states = tearing_assignments (sys)
209
- torn_expr = []
198
+ assignments, deps, bf_states = tearing_assignments (sys)
199
+ var2assignment = Dict {Any,Int} (eq. lhs => i for (i, eq) in enumerate (assignments))
200
+
201
+ torn_expr = Assignment[]
210
202
211
203
defs = defaults (sys)
212
204
nlsolve_scc_idxs = Int[]
213
205
214
206
needs_extending = false
215
207
@views for (i, scc) in enumerate (var_sccs)
216
- # torn_vars = [s.fullvars[var] for var in scc if var_eq_matching[var] !== unassigned]
217
208
torn_vars_idxs = Int[var for var in scc if var_eq_matching[var] != = unassigned]
218
209
torn_eqs_idxs = [var_eq_matching[var] for var in torn_vars_idxs]
219
210
isempty (torn_eqs_idxs) && continue
220
211
if length (torn_eqs_idxs) <= max_inlining_size
221
- append! (torn_expr, gen_nlsolve (eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, checkbounds= checkbounds, assignments= assignments))
212
+ nlsolve_expr = gen_nlsolve (eqs[torn_eqs_idxs], s. fullvars[torn_vars_idxs], defs, assignments, deps, var2assignment, checkbounds= checkbounds)
213
+ #=
214
+ # a temporary vector that we need to reverse to get the correct
215
+ # dependency evaluation order.
216
+ local_deps = Vector{Int}[]
217
+ init_deps = [var2assignment[p] for p in params if haskey(var2assignment, p)]
218
+ push!(local_deps, init_deps)
219
+ while (next_deps = reduce(vcat, deps[init_deps]); !isempty(next_deps))
220
+ init_deps = next_deps
221
+ push!(local_deps, init_deps)
222
+ end
223
+ =#
224
+ append! (torn_expr, nlsolve_expr)
222
225
push! (nlsolve_scc_idxs, i)
223
226
else
224
227
needs_extending = true
@@ -262,13 +265,13 @@ function build_torn_function(
262
265
if expression
263
266
expr, states
264
267
else
265
- observedfun = let sys= sys, dict= Dict (), assignments= assignments, bf_states= bf_states
268
+ observedfun = let sys= sys, dict= Dict (), assignments= assignments, deps = deps, bf_states= bf_states, var2assignment = var2assignment
266
269
function generated_observed (obsvar, u, p, t)
267
270
obs = get! (dict, value (obsvar)) do
268
271
build_observed_function (sys, obsvar, var_eq_matching, var_sccs,
272
+ assignments, deps, bf_states, var2assignment,
269
273
checkbounds= checkbounds,
270
- assignments= assignments,
271
- bf_states= bf_states)
274
+ )
272
275
end
273
276
obs (u, p, t)
274
277
end
@@ -302,12 +305,14 @@ function find_solve_sequence(sccs, vars)
302
305
end
303
306
304
307
function build_observed_function (
305
- sys, ts, var_eq_matching, var_sccs;
308
+ sys, ts, var_eq_matching, var_sccs,
309
+ assignments,
310
+ deps,
311
+ bf_states,
312
+ var2assignment;
306
313
expression= false ,
307
314
output_type= Array,
308
315
checkbounds= true ,
309
- assignments,
310
- bf_states,
311
316
)
312
317
313
318
if (isscalar = ! (ts isa AbstractVector))
@@ -356,7 +361,8 @@ function build_observed_function(
356
361
torn_eqs = map (i-> map (v-> eqs[var_eq_matching[v]], var_sccs[i]), subset)
357
362
torn_vars = map (i-> map (v-> fullvars[v], var_sccs[i]), subset)
358
363
u0map = defaults (sys)
359
- solves = gen_nlsolve .(torn_eqs, torn_vars, (u0map,); checkbounds= checkbounds)
364
+ assignments = copy (assignments)
365
+ solves = gen_nlsolve .(torn_eqs, torn_vars, (u0map,), (assignments,), (deps,), (var2assignment,); checkbounds= checkbounds)
360
366
else
361
367
solves = []
362
368
end
0 commit comments