@@ -574,30 +574,62 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574 # TODO : compute the dependency correctly so that we don't have to do this
575575 obs = [fast_substitute (observed (sys), obs_sub); subeqs]
576576
577+ unknowns = Any[v
578+ for (i, v) in enumerate (fullvars)
579+ if diff_to_var[i] === nothing && ispresent (i)]
580+ if ! isempty (extra_vars)
581+ for v in extra_vars
582+ push! (unknowns, old_fullvars[v])
583+ end
584+ end
585+ @set! sys. unknowns = unknowns
586+
577587 # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
578588 # are equations, add an equation `p ~ [p[1], p[2], ...]`
579589 # allow topsort to reorder them
590+ # only add the new equation if all `p[i]` are present
591+ # we first count the number of times the scalarized form of each observed
592+ # variable occurs in observed equations (and unknowns if it's split).
580593
581- handled_obs_arr = Set ()
582- obs_arr_eqs = Equation[]
594+ # map of array observed variable (unscalarized) to number of its
595+ # scalarized terms that appear in observed equations
596+ arr_obs_occurrences = Dict ()
583597 for eq in obs
584598 lhs = eq. lhs
585599 iscall (lhs) || continue
586600 operation (lhs) === getindex || continue
587- Symbolics. shape (lhs) != = Symbolics. Unknown () || continue
601+ Symbolics. shape (lhs) != Symbolics. Unknown () || continue
588602 arg1 = arguments (lhs)[1 ]
589- arg1 in handled_obs_arr && continue
603+ cnt = get (arr_obs_occurrences, arg1, 0 )
604+ arr_obs_occurrences[arg1] = cnt + 1
605+ continue
606+ end
607+ # count variables in unknowns if they are scalarized forms of variables
608+ # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
609+ # is an observed equation.
610+ for sym in unknowns
611+ iscall (sym) || continue
612+ operation (sym) === getindex || continue
613+ Symbolics. shape (sym) != Symbolics. Unknown () || continue
614+ arg1 = arguments (sym)[1 ]
615+ cnt = get (arr_obs_occurrences, arg1, 0 )
616+ cnt == 0 && continue
617+ arr_obs_occurrences[arg1] = cnt + 1
618+ end
619+ obs_arr_eqs = Equation[]
620+ for (arrvar, cnt) in arr_obs_occurrences
621+ cnt == length (arrvar) || continue
590622 # firstindex returns 1 for multidimensional array symbolics
591- firstind = first (eachindex (arg1))
592- scal = [arg1[i] for i in eachindex (arg1)]
623+ firstind = first (eachindex (arrvar))
624+ scal = [arrvar[i] for i in eachindex (arrvar)]
625+ all (sym -> any (eq -> isequal (sym, eq. lhs)))
593626 # respect non-1-indexed arrays
594627 # TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
595628 # `change_origin` is required because `Origin(firstind)(scal)` makes codegen
596629 # try to `create_array(OffsetArray{...}, ...)` which errors.
597630 # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
598631 # of `scal`.
599- push! (obs_arr_eqs, arg1 ~ change_origin (Origin (firstind), scal))
600- push! (handled_obs_arr, arg1)
632+ push! (obs_arr_eqs, arrvar ~ change_origin (Origin (firstind), scal))
601633 end
602634 append! (obs, obs_arr_eqs)
603635 append! (subeqs, obs_arr_eqs)
@@ -607,15 +639,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
607639 @set! sys. eqs = neweqs
608640 @set! sys. observed = obs
609641
610- unknowns = Any[v
611- for (i, v) in enumerate (fullvars)
612- if diff_to_var[i] === nothing && ispresent (i)]
613- if ! isempty (extra_vars)
614- for v in extra_vars
615- push! (unknowns, old_fullvars[v])
616- end
617- end
618- @set! sys. unknowns = unknowns
619642 @set! sys. substitutions = Substitutions (subeqs, deps)
620643
621644 # Only makes sense for time-dependent
0 commit comments