11using  LinearAlgebra
22
3- using  ModelingToolkit:  process_events, get_preprocess_constants 
3+ using  ModelingToolkit:  process_events
44
55const  MAX_INLINE_NLSOLVE_SIZE =  8 
66
@@ -96,136 +96,6 @@ function torn_system_with_nlsolve_jacobian_sparsity(state, var_eq_matching, var_
9696    sparse (I, J, true , length (eqs_idxs), length (states_idxs))
9797end 
9898
99- function  gen_nlsolve! (is_not_prepended_assignment, eqs, vars, u0map:: AbstractDict ,
100-         assignments, (deps, invdeps), var2assignment; checkbounds =  true )
101-     isempty (vars) &&  throw (ArgumentError (" vars may not be empty" 
102-     length (eqs) ==  length (vars) || 
103-         throw (ArgumentError (" vars must be of the same length as the number of equations to find the roots of" 
104-     rhss =  map (x ->  x. rhs, eqs)
105-     #  We use `vars` instead of `graph` to capture parameters, too.
106-     paramset =  ModelingToolkit. vars (r for  r in  rhss)
107- 
108-     #  Compute necessary assignments for the nlsolve expr
109-     init_assignments =  [var2assignment[p] for  p in  paramset if  haskey (var2assignment, p)]
110-     if  isempty (init_assignments)
111-         needed_assignments_idxs =  Int[]
112-         needed_assignments =  similar (assignments, 0 )
113-     else 
114-         tmp =  [init_assignments]
115-         #  `deps[init_assignments]` gives the dependency of `init_assignments`
116-         while  true 
117-             next_assignments =  unique (reduce (vcat, deps[init_assignments]))
118-             isempty (next_assignments) &&  break 
119-             init_assignments =  next_assignments
120-             push! (tmp, init_assignments)
121-         end 
122-         needed_assignments_idxs =  unique (reduce (vcat, reverse (tmp)))
123-         needed_assignments =  assignments[needed_assignments_idxs]
124-     end 
125- 
126-     #  Compute `params`. They are like enclosed variables
127-     rhsvars =  [ModelingToolkit. vars (r. rhs) for  r in  needed_assignments]
128-     vars_set =  Set (vars)
129-     outer_set =  BitSet ()
130-     inner_set =  BitSet ()
131-     for  (i, vs) in  enumerate (rhsvars)
132-         j =  needed_assignments_idxs[i]
133-         if  isdisjoint (vars_set, vs)
134-             push! (outer_set, j)
135-         else 
136-             push! (inner_set, j)
137-         end 
138-     end 
139-     init_refine =  BitSet ()
140-     for  i in  inner_set
141-         union! (init_refine, invdeps[i])
142-     end 
143-     intersect! (init_refine, outer_set)
144-     setdiff! (outer_set, init_refine)
145-     union! (inner_set, init_refine)
146- 
147-     next_refine =  BitSet ()
148-     while  true 
149-         for  i in  init_refine
150-             id =  invdeps[i]
151-             isempty (id) &&  break 
152-             union! (next_refine, id)
153-         end 
154-         intersect! (next_refine, outer_set)
155-         isempty (next_refine) &&  break 
156-         setdiff! (outer_set, next_refine)
157-         union! (inner_set, next_refine)
158- 
159-         init_refine, next_refine =  next_refine, init_refine
160-         empty! (next_refine)
161-     end 
162-     global2local =  Dict (j =>  i for  (i, j) in  enumerate (needed_assignments_idxs))
163-     inner_idxs =  [global2local[i] for  i in  collect (inner_set)]
164-     outer_idxs =  [global2local[i] for  i in  collect (outer_set)]
165-     extravars =  reduce (union!, rhsvars[inner_idxs], init =  Set ())
166-     union! (paramset, extravars)
167-     setdiff! (paramset, vars)
168-     setdiff! (paramset, [needed_assignments[i]. lhs for  i in  inner_idxs])
169-     union! (paramset, [needed_assignments[i]. lhs for  i in  outer_idxs])
170-     params =  collect (paramset)
171- 
172-     #  splatting to tighten the type
173-     u0 =  []
174-     for  v in  vars
175-         v in  keys (u0map) ||  (push! (u0, 1e-3 ); continue )
176-         u =  substitute (v, u0map)
177-         for  i in  1 : length (u0map)
178-             u =  substitute (u, u0map)
179-             u isa  Number &&  (push! (u0, u); break )
180-         end 
181-         u isa  Number ||  error (" $v  doesn't have a default." 
182-     end 
183-     u0 =  [u0... ]
184-     #  specialize on the scalar case
185-     isscalar =  length (u0) ==  1 
186-     u0 =  isscalar ?  u0[1 ] :  SVector (u0... )
187- 
188-     fname =  gensym (" fun" 
189-     #  f is the function to find roots on
190-     if  isscalar
191-         funex =  rhss[1 ]
192-         pre =  get_preprocess_constants (funex)
193-     else 
194-         funex =  MakeArray (rhss, SVector)
195-         pre =  get_preprocess_constants (rhss)
196-     end 
197-     f =  Func (
198-         [DestructuredArgs (vars, inbounds =  ! checkbounds)
199-          DestructuredArgs (params, inbounds =  ! checkbounds)],
200-         [],
201-         pre (Let (needed_assignments[inner_idxs],
202-             funex,
203-             false ))) |>  SymbolicUtils. Code. toexpr
204- 
205-     #  solver call contains code to call the root-finding solver on the function f
206-     solver_call =  LiteralExpr (quote 
207-         $ numerical_nlsolve ($ fname,
208-             #  initial guess
209-             $ u0,
210-             #  "captured variables"
211-             ($ (params... ),))
212-     end )
213- 
214-     preassignments =  []
215-     for  i in  outer_idxs
216-         ii =  needed_assignments_idxs[i]
217-         is_not_prepended_assignment[ii] ||  continue 
218-         is_not_prepended_assignment[ii] =  false 
219-         push! (preassignments, assignments[ii])
220-     end 
221- 
222-     nlsolve_expr =  Assignment[preassignments
223-                               fname ←  drop_expr (@RuntimeGeneratedFunction (f))
224-                               DestructuredArgs (vars, inbounds =  ! checkbounds) ←  solver_call]
225- 
226-     nlsolve_expr
227- end 
228- 
22999""" 
230100    find_solve_sequence(sccs, vars) 
231101
@@ -242,136 +112,3 @@ function find_solve_sequence(sccs, vars)
242112        return  find_solve_sequence (sccs, vars′)
243113    end 
244114end 
245- 
246- function  build_observed_function (state, ts, var_eq_matching, var_sccs,
247-         is_solver_unknown_idxs,
248-         assignments,
249-         deps,
250-         sol_states,
251-         var2assignment;
252-         expression =  false ,
253-         output_type =  Array,
254-         checkbounds =  true )
255-     is_not_prepended_assignment =  trues (length (assignments))
256-     if  (isscalar =  ! (ts isa  AbstractVector))
257-         ts =  [ts]
258-     end 
259-     ts =  unwrap .(Symbolics. scalarize (ts))
260- 
261-     vars =  Set ()
262-     sys =  state. sys
263-     foreach (Base. Fix1 (vars!, vars), ts)
264-     ivs =  independent_variables (sys)
265-     dep_vars =  collect (setdiff (vars, ivs))
266- 
267-     fullvars =  state. fullvars
268-     s =  state. structure
269-     unknown_vars =  fullvars[is_solver_unknown_idxs]
270-     algvars =  fullvars[.! is_solver_unknown_idxs]
271- 
272-     required_algvars =  Set (intersect (algvars, vars))
273-     obs =  observed (sys)
274-     observed_idx =  Dict (x. lhs =>  i for  (i, x) in  enumerate (obs))
275-     namespaced_to_obs =  Dict (unknowns (sys, x. lhs) =>  x. lhs for  x in  obs)
276-     namespaced_to_sts =  Dict (unknowns (sys, x) =>  x for  x in  unknowns (sys))
277-     sts =  Set (unknowns (sys))
278- 
279-     #  FIXME : This is a rather rough estimate of dependencies. We assume
280-     #  the expression depends on everything before the `maxidx`.
281-     subs =  Dict ()
282-     maxidx =  0 
283-     for  (i, s) in  enumerate (dep_vars)
284-         idx =  get (observed_idx, s, nothing )
285-         if  idx != =  nothing 
286-             idx >  maxidx &&  (maxidx =  idx)
287-         else 
288-             s′ =  get (namespaced_to_obs, s, nothing )
289-             if  s′ != =  nothing 
290-                 subs[s] =  s′
291-                 s =  s′
292-                 idx =  get (observed_idx, s, nothing )
293-             end 
294-             if  idx != =  nothing 
295-                 idx >  maxidx &&  (maxidx =  idx)
296-             elseif  ! (s in  sts)
297-                 s′ =  get (namespaced_to_sts, s, nothing )
298-                 if  s′ != =  nothing 
299-                     subs[s] =  s′
300-                     continue 
301-                 end 
302-                 throw (ArgumentError (" $s  is either an observed nor an unknown variable." 
303-             end 
304-             continue 
305-         end 
306-     end 
307-     ts =  map (t ->  substitute (t, subs), ts)
308-     vs =  Set ()
309-     for  idx in  1 : maxidx
310-         vars! (vs, obs[idx]. rhs)
311-         union! (required_algvars, intersect (algvars, vs))
312-         empty! (vs)
313-     end 
314-     for  eq in  assignments
315-         vars! (vs, eq. rhs)
316-         union! (required_algvars, intersect (algvars, vs))
317-         empty! (vs)
318-     end 
319- 
320-     varidxs =  findall (x ->  x in  required_algvars, fullvars)
321-     subset =  find_solve_sequence (var_sccs, varidxs)
322-     if  ! isempty (subset)
323-         eqs =  equations (sys)
324- 
325-         nested_torn_vars_idxs =  []
326-         for  iscc in  subset
327-             torn_vars_idxs =  Int[var
328-                                  for  var in  var_sccs[iscc]
329-                                  if  var_eq_matching[var] != =  unassigned]
330-             isempty (torn_vars_idxs) ||  push! (nested_torn_vars_idxs, torn_vars_idxs)
331-         end 
332-         torn_eqs =  [[eqs[var_eq_matching[i]] for  i in  idxs]
333-                     for  idxs in  nested_torn_vars_idxs]
334-         torn_vars =  [fullvars[idxs] for  idxs in  nested_torn_vars_idxs]
335-         u0map =  defaults (sys)
336-         assignments =  copy (assignments)
337-         solves =  map (zip (torn_eqs, torn_vars)) do  (eqs, vars)
338-             gen_nlsolve! (is_not_prepended_assignment, eqs, vars,
339-                 u0map, assignments, deps, var2assignment;
340-                 checkbounds =  checkbounds)
341-         end 
342-     else 
343-         solves =  []
344-     end 
345- 
346-     subs =  []
347-     for  sym in  vars
348-         eqidx =  get (observed_idx, sym, nothing )
349-         eqidx ===  nothing  &&  continue 
350-         push! (subs, sym ←  obs[eqidx]. rhs)
351-     end 
352-     pre =  get_postprocess_fbody (sys)
353-     cpre =  get_preprocess_constants ([obs[1 : maxidx];
354-                                      isscalar ?  ts[1 ] :  MakeArray (ts, output_type)])
355-     pre2 =  x ->  pre (cpre (x))
356-     ex =  Code. toexpr (
357-         Func (
358-             [DestructuredArgs (unknown_vars, inbounds =  ! checkbounds)
359-              DestructuredArgs (parameters (sys), inbounds =  ! checkbounds)
360-              independent_variables (sys)],
361-             [],
362-             pre2 (Let (
363-                 [collect (Iterators. flatten (solves))
364-                  assignments[is_not_prepended_assignment]
365-                  map (eq ->  eq. lhs ←  eq. rhs, obs[1 : maxidx])
366-                  subs],
367-                 isscalar ?  ts[1 ] :  MakeArray (ts, output_type),
368-                 false ))),
369-         sol_states)
370- 
371-     expression ?  ex :  drop_expr (@RuntimeGeneratedFunction (ex))
372- end 
373- 
374- struct  ODAEProblem{iip} end 
375- 
376- @deprecate  ODAEProblem (args... ; kw... ) ODEProblem (args... ; kw... )
377- @deprecate  ODAEProblem {iip} (args... ; kw... ) where  {iip} ODEProblem {iip} (args... ; kw... )
0 commit comments