230230=# 
231231
232232function  tearing_reassemble (state:: TearingState , var_eq_matching,
233-         full_var_eq_matching =  nothing ; simplify =  false , mm =  nothing )
233+         full_var_eq_matching =  nothing ; simplify =  false , mm =  nothing , cse_hack  =   true , array_hack  =   true )
234234    @unpack  fullvars, sys, structure =  state
235235    @unpack  solvable_graph, var_to_diff, eq_to_diff, graph =  structure
236236    extra_vars =  Int[]
@@ -574,39 +574,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574    #  TODO : compute the dependency correctly so that we don't have to do this
575575    obs =  [fast_substitute (observed (sys), obs_sub); subeqs]
576576
577-     #  HACK: Substitute non-scalarized symbolic arrays of observed variables
578-     #  E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579-     #  ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580-     #  by the topological sorting and dependency identification pieces
581-     obs_arr_subs =  Dict ()
582- 
583-     for  eq in  obs
584-         lhs =  eq. lhs
585-         iscall (lhs) ||  continue 
586-         operation (lhs) ===  getindex ||  continue 
587-         Symbolics. shape (lhs) != =  Symbolics. Unknown () ||  continue 
588-         arg1 =  arguments (lhs)[1 ]
589-         haskey (obs_arr_subs, arg1) &&  continue 
590-         obs_arr_subs[arg1] =  [arg1[i] for  i in  eachindex (arg1)] #  e.g. p => [p[1], p[2]]
591-         index_first =  eachindex (arg1)[1 ]
592- 
593-         #  respect non-1-indexed arrays
594-         #  TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
595-         obs_arr_subs[arg1] =  Origin (index_first)(obs_arr_subs[arg1])
596-     end 
597-     for  i in  eachindex (neweqs)
598-         neweqs[i] =  fast_substitute (neweqs[i], obs_arr_subs; operator =  Symbolics. Operator)
599-     end 
600-     for  i in  eachindex (obs)
601-         obs[i] =  fast_substitute (obs[i], obs_arr_subs; operator =  Symbolics. Operator)
602-     end 
603-     for  i in  eachindex (subeqs)
604-         subeqs[i] =  fast_substitute (subeqs[i], obs_arr_subs; operator =  Symbolics. Operator)
605-     end 
606- 
607-     @set!  sys. eqs =  neweqs
608-     @set!  sys. observed =  obs
609- 
610577    unknowns =  Any[v
611578                   for  (i, v) in  enumerate (fullvars)
612579                   if  diff_to_var[i] ===  nothing  &&  ispresent (i)]
@@ -616,6 +583,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
616583        end 
617584    end 
618585    @set!  sys. unknowns =  unknowns
586+ 
587+     obs, subeqs, deps =  cse_and_array_hacks (
588+         obs, subeqs, unknowns, neweqs; cse =  cse_hack, array =  array_hack)
589+ 
590+     @set!  sys. eqs =  neweqs
591+     @set!  sys. observed =  obs
592+ 
619593    @set!  sys. substitutions =  Substitutions (subeqs, deps)
620594
621595    #  Only makes sense for time-dependent
@@ -629,6 +603,168 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629603    return  invalidate_cache! (sys)
630604end 
631605
606+ """ 
607+ # HACK 1 
608+ 
609+ Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]` 
610+ gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets 
611+ _very_ expensive. this hack performs a limited form of CSE specifically for this case to 
612+ avoid the unnecessary cost. This and the below hack are implemented simultaneously 
613+ 
614+ # HACK 2 
615+ 
616+ Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an 
617+ equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation 
618+ if all `p[i]` are present and the unscalarized form is used in any equation (observed or 
619+ not) we first count the number of times the scalarized form of each observed variable 
620+ occurs in observed equations (and unknowns if it's split). 
621+ """ 
622+ function  cse_and_array_hacks (obs, subeqs, unknowns, neweqs; cse =  true , array =  true )
623+     #  HACK 1
624+     #  mapping of rhs to temporary CSE variable
625+     #  `f(...) => tmpvar` in above example
626+     rhs_to_tempvar =  Dict ()
627+ 
628+     #  HACK 2
629+     #  map of array observed variable (unscalarized) to number of its
630+     #  scalarized terms that appear in observed equations
631+     arr_obs_occurrences =  Dict ()
632+     #  to check if array variables occur in unscalarized form anywhere
633+     all_vars =  Set ()
634+     for  (i, eq) in  enumerate (obs)
635+         lhs =  eq. lhs
636+         rhs =  eq. rhs
637+         vars! (all_vars, rhs)
638+ 
639+         #  HACK 1
640+         if  cse &&  is_getindexed_array (rhs)
641+             rhs_arr =  arguments (rhs)[1 ]
642+             if  ! haskey (rhs_to_tempvar, rhs_arr)
643+                 tempvar =  gensym (Symbol (lhs))
644+                 N =  length (rhs_arr)
645+                 tempvar =  unwrap (Symbolics. variable (
646+                     tempvar; T =  Symbolics. symtype (rhs_arr)))
647+                 tempvar =  setmetadata (
648+                     tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
649+                 tempeq =  tempvar ~  rhs_arr
650+                 rhs_to_tempvar[rhs_arr] =  tempvar
651+                 push! (obs, tempeq)
652+                 push! (subeqs, tempeq)
653+             end 
654+ 
655+             #  getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
656+             #  so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
657+             #  which fails the topological sort
658+             neweq =  lhs ~  getindex_wrapper (
659+                 rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
660+             obs[i] =  neweq
661+             subeqi =  findfirst (isequal (eq), subeqs)
662+             if  subeqi != =  nothing 
663+                 subeqs[subeqi] =  neweq
664+             end 
665+         end 
666+         #  end HACK 1
667+ 
668+         array ||  continue 
669+         iscall (lhs) ||  continue 
670+         operation (lhs) ===  getindex ||  continue 
671+         Symbolics. shape (lhs) !=  Symbolics. Unknown () ||  continue 
672+         arg1 =  arguments (lhs)[1 ]
673+         cnt =  get (arr_obs_occurrences, arg1, 0 )
674+         arr_obs_occurrences[arg1] =  cnt +  1 
675+         continue 
676+     end 
677+ 
678+     #  Also do CSE for `equations(sys)`
679+     if  cse
680+         for  (i, eq) in  enumerate (neweqs)
681+             (; lhs, rhs) =  eq
682+             is_getindexed_array (rhs) ||  continue 
683+             rhs_arr =  arguments (rhs)[1 ]
684+             if  ! haskey (rhs_to_tempvar, rhs_arr)
685+                 tempvar =  gensym (Symbol (lhs))
686+                 N =  length (rhs_arr)
687+                 tempvar =  unwrap (Symbolics. variable (
688+                     tempvar; T =  Symbolics. symtype (rhs_arr)))
689+                 tempvar =  setmetadata (
690+                     tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
691+                 tempeq =  tempvar ~  rhs_arr
692+                 rhs_to_tempvar[rhs_arr] =  tempvar
693+                 push! (obs, tempeq)
694+                 push! (subeqs, tempeq)
695+             end 
696+             #  don't need getindex_wrapper, but do it anyway to know that this
697+             #  hack took place
698+             neweq =  lhs ~  getindex_wrapper (
699+                 rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
700+             neweqs[i] =  neweq
701+         end 
702+     end 
703+ 
704+     #  count variables in unknowns if they are scalarized forms of variables
705+     #  also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
706+     #  is an observed equation.
707+     for  sym in  unknowns
708+         iscall (sym) ||  continue 
709+         operation (sym) ===  getindex ||  continue 
710+         Symbolics. shape (sym) !=  Symbolics. Unknown () ||  continue 
711+         arg1 =  arguments (sym)[1 ]
712+         cnt =  get (arr_obs_occurrences, arg1, 0 )
713+         cnt ==  0  &&  continue 
714+         arr_obs_occurrences[arg1] =  cnt +  1 
715+     end 
716+     for  eq in  neweqs
717+         vars! (all_vars, eq. rhs)
718+     end 
719+     obs_arr_eqs =  Equation[]
720+     for  (arrvar, cnt) in  arr_obs_occurrences
721+         cnt ==  length (arrvar) ||  continue 
722+         arrvar in  all_vars ||  continue 
723+         #  firstindex returns 1 for multidimensional array symbolics
724+         firstind =  first (eachindex (arrvar))
725+         scal =  [arrvar[i] for  i in  eachindex (arrvar)]
726+         #  respect non-1-indexed arrays
727+         #  TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
728+         #  `change_origin` is required because `Origin(firstind)(scal)` makes codegen
729+         #  try to `create_array(OffsetArray{...}, ...)` which errors.
730+         #  `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
731+         #  of `scal`.
732+         push! (obs_arr_eqs, arrvar ~  change_origin (Origin (firstind), scal))
733+     end 
734+     append! (obs, obs_arr_eqs)
735+     append! (subeqs, obs_arr_eqs)
736+ 
737+     #  need to re-sort subeqs
738+     subeqs =  ModelingToolkit. topsort_equations (subeqs, [eq. lhs for  eq in  subeqs])
739+ 
740+     deps =  Vector{Int}[i ==  1  ?  Int[] :  collect (1 : (i -  1 ))
741+                        for  i in  1 : length (subeqs)]
742+ 
743+     return  obs, subeqs, deps
744+ end 
745+ 
746+ function  is_getindexed_array (rhs)
747+     (! ModelingToolkit. isvariable (rhs) ||  ModelingToolkit. iscalledparameter (rhs)) && 
748+         iscall (rhs) &&  operation (rhs) ===  getindex && 
749+         Symbolics. shape (rhs) !=  Symbolics. Unknown ()
750+ end 
751+ 
752+ #  PART OF HACK 1
753+ getindex_wrapper (x, i) =  x[i... ]
754+ 
755+ @register_symbolic  getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
756+ 
757+ #  PART OF HACK 2
758+ function  change_origin (origin, arr)
759+     return  origin (arr)
760+ end 
761+ 
762+ @register_array_symbolic  change_origin (origin:: Origin , arr:: AbstractArray ) begin 
763+     size =  size (arr)
764+     eltype =  eltype (arr)
765+     ndims =  ndims (arr)
766+ end 
767+ 
632768function  tearing (state:: TearingState ; kwargs... )
633769    state. structure. solvable_graph ===  nothing  &&  find_solvables! (state; kwargs... )
634770    complete! (state. structure)
@@ -643,10 +779,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
643779instead, which calls this function internally. 
644780""" 
645781function  tearing (sys:: AbstractSystem , state =  TearingState (sys); mm =  nothing ,
646-         simplify =  false , kwargs... )
782+         simplify =  false , cse_hack  =   true , array_hack  =   true ,  kwargs... )
647783    var_eq_matching, full_var_eq_matching =  tearing (state)
648784    invalidate_cache! (tearing_reassemble (
649-         state, var_eq_matching, full_var_eq_matching; mm, simplify))
785+         state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack ))
650786end 
651787
652788""" 
@@ -668,7 +804,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
668804the system is balanced. 
669805""" 
670806function  dummy_derivative (sys, state =  TearingState (sys); simplify =  false ,
671-         mm =  nothing , kwargs... )
807+         mm =  nothing , cse_hack  =   true , array_hack  =   true ,  kwargs... )
672808    jac =  let  state =  state
673809        (eqs, vars) ->  begin 
674810            symeqs =  EquationsView (state)[eqs]
@@ -692,5 +828,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
692828    end 
693829    var_eq_matching =  dummy_derivative_graph! (state, jac; state_priority,
694830        kwargs... )
695-     tearing_reassemble (state, var_eq_matching; simplify, mm)
831+     tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack )
696832end 
0 commit comments