@@ -254,44 +254,60 @@ function Base.push!(ev::EquationsView, eq)
254254    push! (ev. ts. extra_eqs, eq)
255255end 
256256
257- function  is_time_dependent_parameter (p, iv)
258-     return  iv != =  nothing  &&  isparameter (p)  &&  iscall (p) && 
259-            (operation (p) ===  getindex &&  is_time_dependent_parameter (arguments (p)[1 ], iv) || 
257+ function  is_time_dependent_parameter (p, allps,  iv)
258+     return  iv != =  nothing  &&  p  in  allps  &&  iscall (p) && 
259+            (operation (p) ===  getindex &&  is_time_dependent_parameter (arguments (p)[1 ], allps,  iv) || 
260260            (args =  arguments (p); length (args)) ==  1  &&  isequal (only (args), iv))
261261end 
262262
263+ function  symbolic_contains (var, set)
264+     var in  set ||  symbolic_type (var) ==  ArraySymbolic () &&  Symbolics. shape (var) !=  Symbolics. Unknown () &&  all (i ->  var[i] in  set, eachindex (var))
265+ end 
266+ 
263267function  TearingState (sys; quick_cancel =  false , check =  true , sort_eqs =  true )
268+     #  flatten system
264269    sys =  flatten (sys)
265270    ivs =  independent_variables (sys)
266271    iv =  length (ivs) ==  1  ?  ivs[1 ] :  nothing 
267-     #  scalarize  array equations, without scalarizing arguments to registered functions 
268-     eqs =  flatten_equations (copy ( equations (sys) ))
272+     #  flatten  array equations
273+     eqs =  flatten_equations (equations (sys))
269274    neqs =  length (eqs)
270-     dervaridxs =  OrderedSet {Int} ()
271-     var2idx =  Dict {Any, Int} ()
272-     symbolic_incidence =  []
273-     fullvars =  []
274275    param_derivative_map =  Dict {BasicSymbolic, Any} ()
275-     var_counter =  Ref (0 )
276-     var_types =  VariableType[]
277-     addvar! =  let  fullvars =  fullvars, var_counter =  var_counter, var_types =  var_types
276+     #  * Scalarize unknowns
277+     dvs =  Set {BasicSymbolic} ()
278+     fullvars =  BasicSymbolic[]
279+     for  x in  unknowns (sys)
280+         push! (dvs, x)
281+         xx =  Symbolics. scalarize (x)
282+         if  xx isa  AbstractArray
283+             union! (dvs, xx)
284+             append! (fullvars, xx)
285+         else 
286+             push! (fullvars, xx)
287+         end 
288+     end 
289+     ps =  Set {BasicSymbolic} ()
290+     for  x in  parameters (sys)
291+         push! (ps, x)
292+         xx =  Symbolics. scalarize (x)
293+         xx isa  AbstractArray &&  union! (dvs, x)
294+     end 
295+     var2idx =  Dict {BasicSymbolic, Int} (v =>  k for  (k, v) in  enumerate (fullvars))
296+     addvar! =  let  fullvars =  fullvars, dvs =  dvs, var2idx =  var2idx
278297        var ->  get! (var2idx, var) do 
298+             push! (dvs, var)
279299            push! (fullvars, var)
280-             push! (var_types, getvariabletype (var))
281-             var_counter[] +=  1 
300+             return  length (fullvars)
282301        end 
283302    end 
284303
285-     vars =  OrderedSet ()
286-     varsvec =  []
304+     #  build symbolic incidence
305+     symbolic_incidence =  Vector{BasicSymbolic}[]
306+     varsbuf =  Set ()
287307    eqs_to_retain =  trues (length (eqs))
288-     for  (i, eq′) in  enumerate (eqs)
289-         if  eq′. lhs isa  Connection
290-             check ?  error (" $(nameof (sys))  has unexpanded `connect` statements" : 
291-             return  nothing 
292-         end 
308+     for  (i, eq) in  enumerate (eqs)
293309        if  iscall (eq′. lhs) &&  (op =  operation (eq′. lhs)) isa  Differential && 
294-            isequal (op. x, iv) &&  is_time_dependent_parameter (only (arguments (eq′. lhs)), iv)
310+            isequal (op. x, iv) &&  is_time_dependent_parameter (only (arguments (eq′. lhs)), ps,  iv)
295311            #  parameter derivatives are opted out by specifying `D(p) ~ missing`, but
296312            #  we want to store `nothing` in the map because that means `fast_substitute`
297313            #  will ignore the rule. We will this identify the presence of `eq′.lhs` in
@@ -301,80 +317,71 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
301317            #  change the equation if the RHS is `missing` so the rest of this loop works
302318            eq′ =  eq′. lhs ~  coalesce (eq′. rhs, 0.0 )
303319        end 
304-         if  _iszero (eq′. lhs)
305-             rhs =  quick_cancel ?  quick_cancel_expr (eq′. rhs) :  eq′. rhs
306-             eq =  eq′
307-         else 
308-             lhs =  quick_cancel ?  quick_cancel_expr (eq′. lhs) :  eq′. lhs
309-             rhs =  quick_cancel ?  quick_cancel_expr (eq′. rhs) :  eq′. rhs
310-             eq =  0  ~  rhs -  lhs
320+         rhs =  quick_cancel ?  quick_cancel_expr (eq. rhs) :  eq. rhs
321+         if  ! _iszero (eq. lhs)
322+             lhs =  quick_cancel ?  quick_cancel_expr (eq. lhs) :  eq. lhs
323+             eq =  eqs[i] =  0  ~  rhs -  lhs
311324        end 
312-         vars! (vars, eq. rhs, op =  Symbolics. Operator)
313-         for  v in  vars
314-             _var, _ =  var_from_nested_derivative (v)
315-             any (isequal (_var), ivs) &&  continue 
316-             if  isparameter (_var) || 
317-                (iscall (_var) &&  isparameter (operation (_var)) ||  isconstant (_var))
318-                 if  is_time_dependent_parameter (_var, iv) && 
319-                    ! haskey (param_derivative_map, Differential (iv)(_var))
325+         empty! (varsbuf)
326+         vars! (varsbuf, eq; op =  Symbolics. Operator)
327+         incidence =  Set {BasicSymbolic} ()
328+         for  v in  varsbuf
329+             #  FIXME : This check still needs to rely on metadata
330+             isconstant (v) &&  continue 
331+             vtype =  getvariabletype (v)
332+             #  additionally track brownians in fullvars
333+             #  TODO : When uniting system types, track brownians in their own field
334+             if  vtype ==  BROWNIAN
335+                 i =  addvar! (v)
336+                 push! (incidence, v)
337+             end 
338+ 
339+             if  symbolic_contains (v, ps)
340+                 if  is_time_dependent_parameter (v, ps, iv) &&  ! haskey (param_derivative_map, Differential (iv)(_var))
320341                    #  Parameter derivatives default to zero - they stay constant
321342                    #  between callbacks
322343                    param_derivative_map[Differential (iv)(_var)] =  0.0 
323344                end 
324345                continue 
325346            end 
326-             v =  scalarize (v)
327-             if  v isa  AbstractArray
328-                 append! (varsvec, v)
329-             else 
330-                 push! (varsvec, v)
331-             end 
332-         end 
333-         isalgeq =  true 
334-         unknownvars =  []
335-         for  var in  varsvec
336-             ModelingToolkit. isdelay (var, iv) &&  continue 
337-             set_incidence =  true 
338-             @label  ANOTHER_VAR
339-             _var, _ =  var_from_nested_derivative (var)
340-             any (isequal (_var), ivs) &&  continue 
341-             if  isparameter (_var) || 
342-                (iscall (_var) &&  isparameter (operation (_var)) ||  isconstant (_var))
343-                 continue 
344-             end 
345-             varidx =  addvar! (var)
346-             set_incidence &&  push! (unknownvars, var)
347- 
348-             dvar =  var
349-             idx =  varidx
350-             while  isdifferential (dvar)
351-                 if  ! (idx in  dervaridxs)
352-                     push! (dervaridxs, idx)
347+ 
348+             if  ! symbolic_contains (v, dvs)
349+                 isvalid =  iscall (v) &&  operation (v) isa  Union{Shift, Sample, Hold}
350+                 v′ =  v
351+                 while  ! isvalid &&  iscall (v′) &&  operation (v′) isa  Union{Differential, Shift}
352+                     v′ =  arguments (v)[1 ]
353+                     if  v′ in  dvs ||  getmetadata (v′, SymScope, LocalScope ()) isa  GlobalScope
354+                         isvalid =  true 
355+                         break 
356+                     end 
357+                 end 
358+                 if  ! isvalid
359+                     throw (ArgumentError (" $v  is present in the system but $v′  is not an unknown." 
353360                end 
354-                 isalgeq =  false 
355-                 dvar =  arguments (dvar)[1 ]
356-                 idx =  addvar! (dvar)
357-             end 
358361
359-             dvar =  var
360-             idx =  varidx
362+                 addvar! (v)
363+                 if  iscall (v) &&  operation (v) isa  Symbolics. Operator &&  ! isdifferential (v) &&  (it =  input_timedomain (v)) != =  nothing 
364+                     v′ =  only (arguments (v))
365+                     addvar! (setmetadata (v′, VariableTimeDomain, it))
366+                 end 
367+             end 
361368
362-             if  iscall (var) &&  operation (var) isa  Symbolics. Operator && 
363-                ! isdifferential (var) &&  (it =  input_timedomain (var)) != =  nothing 
364-                 set_incidence =  false 
365-                 var =  only (arguments (var))
366-                 var =  setmetadata (var, VariableTimeDomain, it)
367-                 @goto  ANOTHER_VAR
369+             if  symbolic_type (v) ==  ArraySymbolic ()
370+                 union! (incidence, collect (v))
371+             else 
372+                 push! (incidence, v)
368373            end 
369374        end 
370-         push! (symbolic_incidence, copy (unknownvars))
371-         empty! (unknownvars)
372-         empty! (vars)
373-         empty! (varsvec)
374-         if  isalgeq
375-             eqs[i] =  eq
376-         else 
377-             eqs[i] =  eqs[i]. lhs ~  rhs
375+ 
376+         push! (symbolic_incidence, collect (incidence))
377+     end 
378+ 
379+     dervaridxs =  Int[]
380+     for  (i, v) in  enumerate (fullvars)
381+         while  isdifferential (v)
382+             push! (dervaridxs, i)
383+             v =  arguments (v)[1 ]
384+             i =  addvar! (v)
378385        end 
379386    end 
380387    eqs =  eqs[eqs_to_retain]
@@ -389,6 +396,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
389396        symbolic_incidence =  symbolic_incidence[sortidxs]
390397    end 
391398
399+     #  Handle shifts - find lowest shift and add intermediates with derivative edges
392400    # ## Handle discrete variables
393401    lowest_shift =  Dict ()
394402    for  var in  fullvars
@@ -428,6 +436,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
428436            end 
429437        end 
430438    end 
439+ 
440+     var_types =  Vector {VariableType} (getvariabletype .(fullvars))
441+ 
431442    #  sort `fullvars` such that the mass matrix is as diagonal as possible.
432443    dervaridxs =  collect (dervaridxs)
433444    sorted_fullvars =  OrderedSet (fullvars[dervaridxs])
@@ -451,6 +462,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
451462    var2idx =  Dict (fullvars .=>  eachindex (fullvars))
452463    dervaridxs =  1 : length (dervaridxs)
453464
465+     #  build `var_to_diff`
454466    nvars =  length (fullvars)
455467    diffvars =  []
456468    var_to_diff =  DiffGraph (nvars, true )
@@ -462,13 +474,15 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
462474        var_to_diff[diffvaridx] =  dervaridx
463475    end 
464476
477+     #  build incidence graph
465478    graph =  BipartiteGraph (neqs, nvars, Val (false ))
466479    for  (ie, vars) in  enumerate (symbolic_incidence), v in  vars
467480        jv =  var2idx[v]
468481        add_edge! (graph, ie, jv)
469482    end 
470483
471484    @set!  sys. eqs =  eqs
485+     @set!  sys. unknowns =  [v for  (i, v) in  enumerate (fullvars) if  var_types[i] !=  BROWNIAN]
472486
473487    eq_to_diff =  DiffGraph (nsrcs (graph))
474488
@@ -731,3 +745,19 @@ function _structural_simplify!(state::TearingState; simplify = false,
731745
732746    ModelingToolkit. invalidate_cache! (sys)
733747end 
748+ 
749+ struct  DifferentiatedVariableNotUnknownError <:  Exception 
750+     differentiated
751+     undifferentiated
752+ end 
753+ 
754+ function  Base. showerror (io:: IO , err:: DifferentiatedVariableNotUnknownError )
755+     undiff =  err. undifferentiated
756+     diff =  err. differentiated
757+     print (io, " Variable $undiff  occurs differentiated as $diff  but is not an unknown of the system." 
758+     scope =  getmetadata (undiff, SymScope, LocalScope ())
759+     depth =  expected_scope_depth (scope)
760+     if  depth >  0 
761+         print (io, " \n Variable $undiff  expects $depth  more levels in the hierarchy to be an unknown." 
762+     end 
763+ end 
0 commit comments