@@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584 end
585585 @set! sys. unknowns = unknowns
586586
587+ # HACK: Since we don't support array equations, any equation of the sort
588+ # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589+ # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590+ # for this case to avoid the unnecessary cost.
591+ # This and the below hack are implemented simultaneously
592+
593+ # mapping of rhs to temporary CSE variable
594+ # `f(...) => tmpvar` in above example
595+ rhs_to_tempvar = Dict ()
596+
587597 # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588598 # are equations, add an equation `p ~ [p[1], p[2], ...]`
589599 # allow topsort to reorder them
@@ -597,9 +607,44 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
597607 arr_obs_occurrences = Dict ()
598608 # to check if array variables occur in unscalarized form anywhere
599609 all_vars = Set ()
600- for eq in obs
601- vars! (all_vars, eq. rhs)
610+ for (i, eq) in enumerate (obs)
602611 lhs = eq. lhs
612+ rhs = eq. rhs
613+ vars! (all_vars, rhs)
614+
615+ # HACK 1
616+ if iscall (rhs) && operation (rhs) === getindex &&
617+ Symbolics. shape (rhs) != Symbolics. Unknown ()
618+ rhs_arr = arguments (rhs)[1 ]
619+ if ! haskey (rhs_to_tempvar, rhs_arr)
620+ tempvar = gensym (Symbol (lhs))
621+ N = length (rhs_arr)
622+ tempvar = unwrap (Symbolics. variable (
623+ tempvar; T = Symbolics. symtype (rhs_arr)))
624+ tempvar = setmetadata (
625+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
626+ tempeq = tempvar ~ rhs_arr
627+ rhs_to_tempvar[rhs_arr] = tempvar
628+ # ideally we would like to do this:
629+ push! (obs, tempeq)
630+ push! (subeqs, tempeq)
631+ # and let topsort_equations handle it, but that treats `x` and `x[1]`
632+ # as different variables and thus doesn
633+ end
634+
635+ # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
636+ # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
637+ # which fails the topological sort
638+ neweq = lhs ~ getindex_wrapper (
639+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
640+ obs[i] = neweq
641+ subeqi = findfirst (isequal (eq), subeqs)
642+ if subeqi != = nothing
643+ subeqs[subeqi] = neweq
644+ end
645+ end
646+ # end HACK 1
647+
603648 iscall (lhs) || continue
604649 operation (lhs) === getindex || continue
605650 Symbolics. shape (lhs) != Symbolics. Unknown () || continue
@@ -640,6 +685,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
640685 end
641686 append! (obs, obs_arr_eqs)
642687 append! (subeqs, obs_arr_eqs)
688+
643689 # need to re-sort subeqs
644690 subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
645691
@@ -659,6 +705,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
659705 return invalidate_cache! (sys)
660706end
661707
708+ # PART OF HACK 1
709+ getindex_wrapper (x, i) = x[i... ]
710+
711+ @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
712+
713+ # PART OF HACK 2
662714function change_origin (origin, arr)
663715 return origin (arr)
664716end
0 commit comments