11using LinearAlgebra
22
3- using ModelingToolkit: process_events, get_preprocess_constants
3+ using ModelingToolkit: process_events
44
55const MAX_INLINE_NLSOLVE_SIZE = 8
66
@@ -96,136 +96,6 @@ function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_
9696 sparse (I, J, true , length (eqs_idxs), length (states_idxs))
9797end
9898
99- function gen_nlsolve! (is_not_prepended_assignment, eqs, vars, u0map:: AbstractDict ,
100- assignments, (deps, invdeps), var2assignment; checkbounds = true )
101- isempty (vars) && throw (ArgumentError (" vars may not be empty" ))
102- length (eqs) == length (vars) ||
103- throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" ))
104- rhss = map (x -> x. rhs, eqs)
105- # We use `vars` instead of `graph` to capture parameters, too.
106- paramset = ModelingToolkit. vars (r for r in rhss)
107-
108- # Compute necessary assignments for the nlsolve expr
109- init_assignments = [var2assignment[p] for p in paramset if haskey (var2assignment, p)]
110- if isempty (init_assignments)
111- needed_assignments_idxs = Int[]
112- needed_assignments = similar (assignments, 0 )
113- else
114- tmp = [init_assignments]
115- # `deps[init_assignments]` gives the dependency of `init_assignments`
116- while true
117- next_assignments = unique (reduce (vcat, deps[init_assignments]))
118- isempty (next_assignments) && break
119- init_assignments = next_assignments
120- push! (tmp, init_assignments)
121- end
122- needed_assignments_idxs = unique (reduce (vcat, reverse (tmp)))
123- needed_assignments = assignments[needed_assignments_idxs]
124- end
125-
126- # Compute `params`. They are like enclosed variables
127- rhsvars = [ModelingToolkit. vars (r. rhs) for r in needed_assignments]
128- vars_set = Set (vars)
129- outer_set = BitSet ()
130- inner_set = BitSet ()
131- for (i, vs) in enumerate (rhsvars)
132- j = needed_assignments_idxs[i]
133- if isdisjoint (vars_set, vs)
134- push! (outer_set, j)
135- else
136- push! (inner_set, j)
137- end
138- end
139- init_refine = BitSet ()
140- for i in inner_set
141- union! (init_refine, invdeps[i])
142- end
143- intersect! (init_refine, outer_set)
144- setdiff! (outer_set, init_refine)
145- union! (inner_set, init_refine)
146-
147- next_refine = BitSet ()
148- while true
149- for i in init_refine
150- id = invdeps[i]
151- isempty (id) && break
152- union! (next_refine, id)
153- end
154- intersect! (next_refine, outer_set)
155- isempty (next_refine) && break
156- setdiff! (outer_set, next_refine)
157- union! (inner_set, next_refine)
158-
159- init_refine, next_refine = next_refine, init_refine
160- empty! (next_refine)
161- end
162- global2local = Dict (j => i for (i, j) in enumerate (needed_assignments_idxs))
163- inner_idxs = [global2local[i] for i in collect (inner_set)]
164- outer_idxs = [global2local[i] for i in collect (outer_set)]
165- extravars = reduce (union!, rhsvars[inner_idxs], init = Set ())
166- union! (paramset, extravars)
167- setdiff! (paramset, vars)
168- setdiff! (paramset, [needed_assignments[i]. lhs for i in inner_idxs])
169- union! (paramset, [needed_assignments[i]. lhs for i in outer_idxs])
170- params = collect (paramset)
171-
172- # splatting to tighten the type
173- u0 = []
174- for v in vars
175- v in keys (u0map) || (push! (u0, 1e-3 ); continue )
176- u = substitute (v, u0map)
177- for i in 1 : length (u0map)
178- u = substitute (u, u0map)
179- u isa Number && (push! (u0, u); break )
180- end
181- u isa Number || error (" $v doesn't have a default." )
182- end
183- u0 = [u0... ]
184- # specialize on the scalar case
185- isscalar = length (u0) == 1
186- u0 = isscalar ? u0[1 ] : SVector (u0... )
187-
188- fname = gensym (" fun" )
189- # f is the function to find roots on
190- if isscalar
191- funex = rhss[1 ]
192- pre = get_preprocess_constants (funex)
193- else
194- funex = MakeArray (rhss, SVector)
195- pre = get_preprocess_constants (rhss)
196- end
197- f = Func (
198- [DestructuredArgs (vars, inbounds = ! checkbounds)
199- DestructuredArgs (params, inbounds = ! checkbounds)],
200- [],
201- pre (Let (needed_assignments[inner_idxs],
202- funex,
203- false ))) |> SymbolicUtils. Code. toexpr
204-
205- # solver call contains code to call the root-finding solver on the function f
206- solver_call = LiteralExpr (quote
207- $ numerical_nlsolve ($ fname,
208- # initial guess
209- $ u0,
210- # "captured variables"
211- ($ (params... ),))
212- end )
213-
214- preassignments = []
215- for i in outer_idxs
216- ii = needed_assignments_idxs[i]
217- is_not_prepended_assignment[ii] || continue
218- is_not_prepended_assignment[ii] = false
219- push! (preassignments, assignments[ii])
220- end
221-
222- nlsolve_expr = Assignment[preassignments
223- fname ← drop_expr (@RuntimeGeneratedFunction (f))
224- DestructuredArgs (vars, inbounds = ! checkbounds) ← solver_call]
225-
226- nlsolve_expr
227- end
228-
22999"""
230100 find_solve_sequence(sccs, vars)
231101
@@ -242,136 +112,3 @@ function find_solve_sequence(sccs, vars)
242112 return find_solve_sequence (sccs, vars′)
243113 end
244114end
245-
246- function build_observed_function (state, ts, var_eq_matching, var_sccs,
247- is_solver_unknown_idxs,
248- assignments,
249- deps,
250- sol_states,
251- var2assignment;
252- expression = false ,
253- output_type = Array,
254- checkbounds = true )
255- is_not_prepended_assignment = trues (length (assignments))
256- if (isscalar = ! (ts isa AbstractVector))
257- ts = [ts]
258- end
259- ts = unwrap .(Symbolics. scalarize (ts))
260-
261- vars = Set ()
262- sys = state. sys
263- foreach (Base. Fix1 (vars!, vars), ts)
264- ivs = independent_variables (sys)
265- dep_vars = collect (setdiff (vars, ivs))
266-
267- fullvars = state. fullvars
268- s = state. structure
269- unknown_vars = fullvars[is_solver_unknown_idxs]
270- algvars = fullvars[.! is_solver_unknown_idxs]
271-
272- required_algvars = Set (intersect (algvars, vars))
273- obs = observed (sys)
274- observed_idx = Dict (x. lhs => i for (i, x) in enumerate (obs))
275- namespaced_to_obs = Dict (unknowns (sys, x. lhs) => x. lhs for x in obs)
276- namespaced_to_sts = Dict (unknowns (sys, x) => x for x in unknowns (sys))
277- sts = Set (unknowns (sys))
278-
279- # FIXME : This is a rather rough estimate of dependencies. We assume
280- # the expression depends on everything before the `maxidx`.
281- subs = Dict ()
282- maxidx = 0
283- for (i, s) in enumerate (dep_vars)
284- idx = get (observed_idx, s, nothing )
285- if idx != = nothing
286- idx > maxidx && (maxidx = idx)
287- else
288- s′ = get (namespaced_to_obs, s, nothing )
289- if s′ != = nothing
290- subs[s] = s′
291- s = s′
292- idx = get (observed_idx, s, nothing )
293- end
294- if idx != = nothing
295- idx > maxidx && (maxidx = idx)
296- elseif ! (s in sts)
297- s′ = get (namespaced_to_sts, s, nothing )
298- if s′ != = nothing
299- subs[s] = s′
300- continue
301- end
302- throw (ArgumentError (" $s is either an observed nor an unknown variable." ))
303- end
304- continue
305- end
306- end
307- ts = map (t -> substitute (t, subs), ts)
308- vs = Set ()
309- for idx in 1 : maxidx
310- vars! (vs, obs[idx]. rhs)
311- union! (required_algvars, intersect (algvars, vs))
312- empty! (vs)
313- end
314- for eq in assignments
315- vars! (vs, eq. rhs)
316- union! (required_algvars, intersect (algvars, vs))
317- empty! (vs)
318- end
319-
320- varidxs = findall (x -> x in required_algvars, fullvars)
321- subset = find_solve_sequence (var_sccs, varidxs)
322- if ! isempty (subset)
323- eqs = equations (sys)
324-
325- nested_torn_vars_idxs = []
326- for iscc in subset
327- torn_vars_idxs = Int[var
328- for var in var_sccs[iscc]
329- if var_eq_matching[var] != = unassigned]
330- isempty (torn_vars_idxs) || push! (nested_torn_vars_idxs, torn_vars_idxs)
331- end
332- torn_eqs = [[eqs[var_eq_matching[i]] for i in idxs]
333- for idxs in nested_torn_vars_idxs]
334- torn_vars = [fullvars[idxs] for idxs in nested_torn_vars_idxs]
335- u0map = defaults (sys)
336- assignments = copy (assignments)
337- solves = map (zip (torn_eqs, torn_vars)) do (eqs, vars)
338- gen_nlsolve! (is_not_prepended_assignment, eqs, vars,
339- u0map, assignments, deps, var2assignment;
340- checkbounds = checkbounds)
341- end
342- else
343- solves = []
344- end
345-
346- subs = []
347- for sym in vars
348- eqidx = get (observed_idx, sym, nothing )
349- eqidx === nothing && continue
350- push! (subs, sym ← obs[eqidx]. rhs)
351- end
352- pre = get_postprocess_fbody (sys)
353- cpre = get_preprocess_constants ([obs[1 : maxidx];
354- isscalar ? ts[1 ] : MakeArray (ts, output_type)])
355- pre2 = x -> pre (cpre (x))
356- ex = Code. toexpr (
357- Func (
358- [DestructuredArgs (unknown_vars, inbounds = ! checkbounds)
359- DestructuredArgs (parameters (sys), inbounds = ! checkbounds)
360- independent_variables (sys)],
361- [],
362- pre2 (Let (
363- [collect (Iterators. flatten (solves))
364- assignments[is_not_prepended_assignment]
365- map (eq -> eq. lhs ← eq. rhs, obs[1 : maxidx])
366- subs],
367- isscalar ? ts[1 ] : MakeArray (ts, output_type),
368- false ))),
369- sol_states)
370-
371- expression ? ex : drop_expr (@RuntimeGeneratedFunction (ex))
372- end
373-
374- struct ODAEProblem{iip} end
375-
376- @deprecate ODAEProblem (args... ; kw... ) ODEProblem (args... ; kw... )
377- @deprecate ODAEProblem {iip} (args... ; kw... ) where {iip} ODEProblem {iip} (args... ; kw... )
0 commit comments