@@ -253,105 +253,106 @@ function Base.push!(ev::EquationsView, eq)
253253    push! (ev. ts. extra_eqs, eq)
254254end 
255255
256+ function  symbolic_contains (var, set)
257+     var in  set ||  symbolic_type (var) ==  ArraySymbolic () &&  Symbolics. shape (var) !=  Symbolics. Unknown () &&  all (i ->  var[i] in  set, eachindex (var))
258+ end 
259+ 
256260function  TearingState (sys; quick_cancel =  false , check =  true )
261+     #  flatten system
257262    sys =  flatten (sys)
258263    ivs =  independent_variables (sys)
259264    iv =  length (ivs) ==  1  ?  ivs[1 ] :  nothing 
260-     #  scalarize  array equations, without scalarizing arguments to registered functions 
261-     eqs =  flatten_equations (copy ( equations (sys) ))
265+     #  flatten  array equations
266+     eqs =  flatten_equations (equations (sys))
262267    neqs =  length (eqs)
263-     dervaridxs =  OrderedSet {Int} ()
264-     var2idx =  Dict {Any, Int} ()
265-     symbolic_incidence =  []
266-     fullvars =  []
267-     var_counter =  Ref (0 )
268-     var_types =  VariableType[]
269-     addvar! =  let  fullvars =  fullvars, var_counter =  var_counter, var_types =  var_types
268+     #  * Scalarize unknowns
269+     dvs =  Set {BasicSymbolic} ()
270+     fullvars =  BasicSymbolic[]
271+     for  x in  unknowns (sys)
272+         push! (dvs, x)
273+         xx =  Symbolics. scalarize (x)
274+         if  xx isa  AbstractArray
275+             union! (dvs, xx)
276+             append! (fullvars, xx)
277+         else 
278+             push! (fullvars, xx)
279+         end 
280+     end 
281+     var2idx =  Dict {BasicSymbolic, Int} (v =>  k for  (k, v) in  enumerate (fullvars))
282+     addvar! =  let  fullvars =  fullvars, dvs =  dvs, var2idx =  var2idx
270283        var ->  get! (var2idx, var) do 
284+             push! (dvs, var)
271285            push! (fullvars, var)
272-             push! (var_types, getvariabletype (var))
273-             var_counter[] +=  1 
286+             return  length (fullvars)
274287        end 
275288    end 
276289
277-     vars =  OrderedSet ()
278-     varsvec =  []
279-     for  (i, eq′) in  enumerate (eqs)
280-         if  eq′. lhs isa  Connection
281-             check ?  error (" $(nameof (sys))  has unexpanded `connect` statements" : 
282-             return  nothing 
283-         end 
284-         if  _iszero (eq′. lhs)
285-             rhs =  quick_cancel ?  quick_cancel_expr (eq′. rhs) :  eq′. rhs
286-             eq =  eq′
287-         else 
288-             lhs =  quick_cancel ?  quick_cancel_expr (eq′. lhs) :  eq′. lhs
289-             rhs =  quick_cancel ?  quick_cancel_expr (eq′. rhs) :  eq′. rhs
290-             eq =  0  ~  rhs -  lhs
290+     #  build symbolic incidence
291+     symbolic_incidence =  Vector{BasicSymbolic}[]
292+     varsbuf =  Set ()
293+     for  (i, eq) in  enumerate (eqs)
294+         rhs =  quick_cancel ?  quick_cancel_expr (eq. rhs) :  eq. rhs
295+         if  ! _iszero (eq. lhs)
296+             lhs =  quick_cancel ?  quick_cancel_expr (eq. lhs) :  eq. lhs
297+             eq =  eqs[i] =  0  ~  rhs -  lhs
291298        end 
292-         vars! (vars, eq. rhs, op =  Symbolics. Operator)
293-         for  v in  vars
294-             _var, _ =  var_from_nested_derivative (v)
295-             any (isequal (_var), ivs) &&  continue 
296-             if  isparameter (_var) || 
297-                (iscall (_var) &&  isparameter (operation (_var)) ||  isconstant (_var))
298-                 continue 
299+         empty! (varsbuf)
300+         vars! (varsbuf, eq; op =  Symbolics. Operator)
301+         incidence =  Set {BasicSymbolic} ()
302+         for  v in  varsbuf
303+             #  FIXME : This check still needs to rely on metadata
304+             isconstant (v) &&  continue 
305+             vtype =  getvariabletype (v)
306+             #  additionally track brownians in fullvars
307+             #  TODO : When uniting system types, track brownians in their own field
308+             if  vtype ==  BROWNIAN
309+                 i =  addvar! (v)
310+                 push! (incidence, v)
299311            end 
300-             v =  scalarize (v)
301-             if  v isa  AbstractArray
302-                 append! (varsvec, v)
303-             else 
304-                 push! (varsvec, v)
305-             end 
306-         end 
307-         isalgeq =  true 
308-         unknownvars =  []
309-         for  var in  varsvec
310-             ModelingToolkit. isdelay (var, iv) &&  continue 
311-             set_incidence =  true 
312-             @label  ANOTHER_VAR
313-             _var, _ =  var_from_nested_derivative (var)
314-             any (isequal (_var), ivs) &&  continue 
315-             if  isparameter (_var) || 
316-                (iscall (_var) &&  isparameter (operation (_var)) ||  isconstant (_var))
317-                 continue 
318-             end 
319-             varidx =  addvar! (var)
320-             set_incidence &&  push! (unknownvars, var)
321- 
322-             dvar =  var
323-             idx =  varidx
324-             while  isdifferential (dvar)
325-                 if  ! (idx in  dervaridxs)
326-                     push! (dervaridxs, idx)
312+ 
313+             vtype ==  VARIABLE ||  continue 
314+ 
315+             if  ! symbolic_contains (v, dvs)
316+                 isvalid =  iscall (v) &&  operation (v) isa  Union{Shift, Sample, Hold}
317+                 v′ =  v
318+                 while  ! isvalid &&  iscall (v′) &&  operation (v′) isa  Union{Differential, Shift}
319+                     v′ =  arguments (v)[1 ]
320+                     if  v′ in  dvs ||  getmetadata (v′, SymScope, LocalScope ()) isa  GlobalScope
321+                         isvalid =  true 
322+                         break 
323+                     end 
324+                 end 
325+                 if  ! isvalid
326+                     throw (ArgumentError (" $v  is present in the system but $v′  is not an unknown." 
327327                end 
328-                 isalgeq =  false 
329-                 dvar =  arguments (dvar)[1 ]
330-                 idx =  addvar! (dvar)
331-             end 
332328
333-             dvar =  var
334-             idx =  varidx
329+                 addvar! (v)
330+                 if  iscall (v) &&  operation (v) isa  Symbolics. Operator &&  ! isdifferential (v) &&  (it =  input_timedomain (v)) != =  nothing 
331+                     v′ =  only (arguments (v))
332+                     addvar! (setmetadata (v′, VariableTimeDomain, it))
333+                 end 
334+             end 
335335
336-             if  iscall (var) &&  operation (var) isa  Symbolics. Operator && 
337-                ! isdifferential (var) &&  (it =  input_timedomain (var)) != =  nothing 
338-                 set_incidence =  false 
339-                 var =  only (arguments (var))
340-                 var =  setmetadata (var, VariableTimeDomain, it)
341-                 @goto  ANOTHER_VAR
336+             if  symbolic_type (v) ==  ArraySymbolic ()
337+                 union! (incidence, collect (v))
338+             else 
339+                 push! (incidence, v)
342340            end 
343341        end 
344-         push! (symbolic_incidence, copy (unknownvars))
345-         empty! (unknownvars)
346-         empty! (vars)
347-         empty! (varsvec)
348-         if  isalgeq
349-             eqs[i] =  eq
350-         else 
351-             eqs[i] =  eqs[i]. lhs ~  rhs
342+ 
343+         push! (symbolic_incidence, collect (incidence))
344+     end 
345+ 
346+     dervaridxs =  Int[]
347+     for  (i, v) in  enumerate (fullvars)
348+         while  isdifferential (v)
349+             push! (dervaridxs, i)
350+             v =  arguments (v)[1 ]
351+             i =  addvar! (v)
352352        end 
353353    end 
354354
355+     #  Handle shifts - find lowest shift and add intermediates with derivative edges
355356    # ## Handle discrete variables
356357    lowest_shift =  Dict ()
357358    for  var in  fullvars
@@ -391,6 +392,9 @@ function TearingState(sys; quick_cancel = false, check = true)
391392            end 
392393        end 
393394    end 
395+ 
396+     var_types =  Vector {VariableType} (getvariabletype .(fullvars))
397+ 
394398    #  sort `fullvars` such that the mass matrix is as diagonal as possible.
395399    dervaridxs =  collect (dervaridxs)
396400    sorted_fullvars =  OrderedSet (fullvars[dervaridxs])
@@ -414,6 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true)
414418    var2idx =  Dict (fullvars .=>  eachindex (fullvars))
415419    dervaridxs =  1 : length (dervaridxs)
416420
421+     #  build `var_to_diff`
417422    nvars =  length (fullvars)
418423    diffvars =  []
419424    var_to_diff =  DiffGraph (nvars, true )
@@ -425,20 +430,24 @@ function TearingState(sys; quick_cancel = false, check = true)
425430        var_to_diff[diffvaridx] =  dervaridx
426431    end 
427432
433+     #  build incidence graph
428434    graph =  BipartiteGraph (neqs, nvars, Val (false ))
429435    for  (ie, vars) in  enumerate (symbolic_incidence), v in  vars
430436        jv =  var2idx[v]
431437        add_edge! (graph, ie, jv)
432438    end 
433439
434440    @set!  sys. eqs =  eqs
441+     @set!  sys. unknowns =  [v for  (i, v) in  enumerate (fullvars) if  var_types[i] !=  BROWNIAN]
435442
436443    eq_to_diff =  DiffGraph (nsrcs (graph))
437444
438445    ts =  TearingState (sys, fullvars,
439446        SystemStructure (complete (var_to_diff), complete (eq_to_diff),
440447            complete (graph), nothing , var_types, sys isa  AbstractDiscreteSystem),
441448        Any[])
449+ 
450+     #  `shift_discrete_system`
442451    if  sys isa  DiscreteSystem
443452        ts =  shift_discrete_system (ts)
444453    end 
@@ -726,3 +735,19 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
726735
727736    ModelingToolkit. invalidate_cache! (sys), input_idxs
728737end 
738+ 
739+ struct  DifferentiatedVariableNotUnknownError <:  Exception 
740+     differentiated
741+     undifferentiated
742+ end 
743+ 
744+ function  Base. showerror (io:: IO , err:: DifferentiatedVariableNotUnknownError )
745+     undiff =  err. undifferentiated
746+     diff =  err. differentiated
747+     print (io, " Variable $undiff  occurs differentiated as $diff  but is not an unknown of the system." 
748+     scope =  getmetadata (undiff, SymScope, LocalScope ())
749+     depth =  expected_scope_depth (scope)
750+     if  depth >  0 
751+         print (io, " \n Variable $undiff  expects $depth  more levels in the hierarchy to be an unknown." 
752+     end 
753+ end 
0 commit comments