@@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584
584
end
585
585
@set! sys. unknowns = unknowns
586
586
587
+ # HACK: Since we don't support array equations, any equation of the sort
588
+ # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589
+ # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590
+ # for this case to avoid the unnecessary cost.
591
+ # This and the below hack are implemented simultaneously
592
+
593
+ # mapping of rhs to temporary CSE variable
594
+ # `f(...) => tmpvar` in above example
595
+ rhs_to_tempvar = Dict ()
596
+
587
597
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588
598
# are equations, add an equation `p ~ [p[1], p[2], ...]`
589
599
# allow topsort to reorder them
@@ -597,9 +607,42 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
597
607
arr_obs_occurrences = Dict ()
598
608
# to check if array variables occur in unscalarized form anywhere
599
609
all_vars = Set ()
600
- for eq in obs
601
- vars! (all_vars, eq. rhs)
610
+ for (i, eq) in enumerate (obs)
602
611
lhs = eq. lhs
612
+ rhs = eq. rhs
613
+ vars! (all_vars, rhs)
614
+
615
+ # HACK 1
616
+ if (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
617
+ iscall (rhs) && operation (rhs) === getindex &&
618
+ Symbolics. shape (rhs) != Symbolics. Unknown ()
619
+ rhs_arr = arguments (rhs)[1 ]
620
+ if ! haskey (rhs_to_tempvar, rhs_arr)
621
+ tempvar = gensym (Symbol (lhs))
622
+ N = length (rhs_arr)
623
+ tempvar = unwrap (Symbolics. variable (
624
+ tempvar; T = Symbolics. symtype (rhs_arr)))
625
+ tempvar = setmetadata (
626
+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
627
+ tempeq = tempvar ~ rhs_arr
628
+ rhs_to_tempvar[rhs_arr] = tempvar
629
+ push! (obs, tempeq)
630
+ push! (subeqs, tempeq)
631
+ end
632
+
633
+ # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
634
+ # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
635
+ # which fails the topological sort
636
+ neweq = lhs ~ getindex_wrapper (
637
+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
638
+ obs[i] = neweq
639
+ subeqi = findfirst (isequal (eq), subeqs)
640
+ if subeqi != = nothing
641
+ subeqs[subeqi] = neweq
642
+ end
643
+ end
644
+ # end HACK 1
645
+
603
646
iscall (lhs) || continue
604
647
operation (lhs) === getindex || continue
605
648
Symbolics. shape (lhs) != Symbolics. Unknown () || continue
@@ -640,6 +683,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
640
683
end
641
684
append! (obs, obs_arr_eqs)
642
685
append! (subeqs, obs_arr_eqs)
686
+
643
687
# need to re-sort subeqs
644
688
subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
645
689
@@ -659,6 +703,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
659
703
return invalidate_cache! (sys)
660
704
end
661
705
706
+ # PART OF HACK 1
707
+ getindex_wrapper (x, i) = x[i... ]
708
+
709
+ @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
710
+
711
+ # PART OF HACK 2
662
712
function change_origin (origin, arr)
663
713
return origin (arr)
664
714
end
0 commit comments