@@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584 end
585585 @set! sys. unknowns = unknowns
586586
587+ # HACK: Since we don't support array equations, any equation of the sort
588+ # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589+ # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590+ # for this case to avoid the unnecessary cost.
591+ # This and the below hack are implemented simultaneously
592+
593+ # mapping of rhs to temporary CSE variable
594+ # `f(...) => tmpvar` in above example
595+ rhs_to_tempvar = Dict ()
596+
587597 # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588598 # are equations, add an equation `p ~ [p[1], p[2], ...]`
589599 # allow topsort to reorder them
@@ -597,9 +607,42 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
597607 arr_obs_occurrences = Dict ()
598608 # to check if array variables occur in unscalarized form anywhere
599609 all_vars = Set ()
600- for eq in obs
601- vars! (all_vars, eq. rhs)
610+ for (i, eq) in enumerate (obs)
602611 lhs = eq. lhs
612+ rhs = eq. rhs
613+ vars! (all_vars, rhs)
614+
615+ # HACK 1
616+ if (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
617+ iscall (rhs) && operation (rhs) === getindex &&
618+ Symbolics. shape (rhs) != Symbolics. Unknown ()
619+ rhs_arr = arguments (rhs)[1 ]
620+ if ! haskey (rhs_to_tempvar, rhs_arr)
621+ tempvar = gensym (Symbol (lhs))
622+ N = length (rhs_arr)
623+ tempvar = unwrap (Symbolics. variable (
624+ tempvar; T = Symbolics. symtype (rhs_arr)))
625+ tempvar = setmetadata (
626+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
627+ tempeq = tempvar ~ rhs_arr
628+ rhs_to_tempvar[rhs_arr] = tempvar
629+ push! (obs, tempeq)
630+ push! (subeqs, tempeq)
631+ end
632+
633+ # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
634+ # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
635+ # which fails the topological sort
636+ neweq = lhs ~ getindex_wrapper (
637+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
638+ obs[i] = neweq
639+ subeqi = findfirst (isequal (eq), subeqs)
640+ if subeqi != = nothing
641+ subeqs[subeqi] = neweq
642+ end
643+ end
644+ # end HACK 1
645+
603646 iscall (lhs) || continue
604647 operation (lhs) === getindex || continue
605648 Symbolics. shape (lhs) != Symbolics. Unknown () || continue
@@ -640,6 +683,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
640683 end
641684 append! (obs, obs_arr_eqs)
642685 append! (subeqs, obs_arr_eqs)
686+
643687 # need to re-sort subeqs
644688 subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
645689
@@ -659,6 +703,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
659703 return invalidate_cache! (sys)
660704end
661705
706+ # PART OF HACK 1
707+ getindex_wrapper (x, i) = x[i... ]
708+
709+ @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
710+
711+ # PART OF HACK 2
662712function change_origin (origin, arr)
663713 return origin (arr)
664714end
0 commit comments