@@ -593,7 +593,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
593
593
@set! sys. unknowns = unknowns
594
594
595
595
obs, subeqs, deps = cse_and_array_hacks (
596
- obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
596
+ sys, obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
597
597
598
598
@set! sys. eqs = neweqs
599
599
@set! sys. observed = obs
@@ -627,7 +627,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
627
627
not) we first count the number of times the scalarized form of each observed variable
628
628
occurs in observed equations (and unknowns if it's split).
629
629
"""
630
- function cse_and_array_hacks (obs, subeqs, unknowns, neweqs; cse = true , array = true )
630
+ function cse_and_array_hacks (sys, obs, subeqs, unknowns, neweqs; cse = true , array = true )
631
631
# HACK 1
632
632
# mapping of rhs to temporary CSE variable
633
633
# `f(...) => tmpvar` in above example
@@ -725,6 +725,11 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
725
725
for eq in neweqs
726
726
vars! (all_vars, eq. rhs)
727
727
end
728
+
729
+ # also count unscalarized variables used in callbacks
730
+ for ev in Iterators. flatten ((continuous_events (sys), discrete_events (sys)))
731
+ vars! (all_vars, ev)
732
+ end
728
733
obs_arr_eqs = Equation[]
729
734
for (arrvar, cnt) in arr_obs_occurrences
730
735
cnt == length (arrvar) || continue
0 commit comments