@@ -574,48 +574,78 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574 # TODO : compute the dependency correctly so that we don't have to do this
575575 obs = [fast_substitute (observed (sys), obs_sub); subeqs]
576576
577- # HACK: Substitute non-scalarized symbolic arrays of observed variables
578- # E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579- # ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580- # by the topological sorting and dependency identification pieces
581- obs_arr_subs = Dict ()
577+ unknowns = Any[v
578+ for (i, v) in enumerate (fullvars)
579+ if diff_to_var[i] === nothing && ispresent (i)]
580+ if ! isempty (extra_vars)
581+ for v in extra_vars
582+ push! (unknowns, old_fullvars[v])
583+ end
584+ end
585+ @set! sys. unknowns = unknowns
582586
587+ # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588+ # are equations, add an equation `p ~ [p[1], p[2], ...]`
589+ # allow topsort to reorder them
590+ # only add the new equation if all `p[i]` are present and the unscalarized
591+ # form is used in any equation (observed or not)
592+ # we first count the number of times the scalarized form of each observed
593+ # variable occurs in observed equations (and unknowns if it's split).
594+
595+ # map of array observed variable (unscalarized) to number of its
596+ # scalarized terms that appear in observed equations
597+ arr_obs_occurrences = Dict ()
598+ # to check if array variables occur in unscalarized form anywhere
599+ all_vars = Set ()
583600 for eq in obs
601+ vars! (all_vars, eq. rhs)
584602 lhs = eq. lhs
585603 iscall (lhs) || continue
586604 operation (lhs) === getindex || continue
587- Symbolics. shape (lhs) != = Symbolics. Unknown () || continue
605+ Symbolics. shape (lhs) != Symbolics. Unknown () || continue
588606 arg1 = arguments (lhs)[1 ]
589- haskey (obs_arr_subs, arg1) && continue
590- obs_arr_subs[arg1] = [arg1[i] for i in eachindex (arg1)] # e.g. p => [p[1], p[2]]
591- index_first = eachindex (arg1)[1 ]
592-
607+ cnt = get (arr_obs_occurrences, arg1, 0 )
608+ arr_obs_occurrences[arg1] = cnt + 1
609+ continue
610+ end
611+ # count variables in unknowns if they are scalarized forms of variables
612+ # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
613+ # is an observed equation.
614+ for sym in unknowns
615+ iscall (sym) || continue
616+ operation (sym) === getindex || continue
617+ Symbolics. shape (sym) != Symbolics. Unknown () || continue
618+ arg1 = arguments (sym)[1 ]
619+ cnt = get (arr_obs_occurrences, arg1, 0 )
620+ cnt == 0 && continue
621+ arr_obs_occurrences[arg1] = cnt + 1
622+ end
623+ for eq in neweqs
624+ vars! (all_vars, eq. rhs)
625+ end
626+ obs_arr_eqs = Equation[]
627+ for (arrvar, cnt) in arr_obs_occurrences
628+ cnt == length (arrvar) || continue
629+ arrvar in all_vars || continue
630+ # firstindex returns 1 for multidimensional array symbolics
631+ firstind = first (eachindex (arrvar))
632+ scal = [arrvar[i] for i in eachindex (arrvar)]
593633 # respect non-1-indexed arrays
594634 # TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
595- obs_arr_subs[arg1] = Origin (index_first)(obs_arr_subs[arg1])
596- end
597- for i in eachindex (neweqs)
598- neweqs[i] = fast_substitute (neweqs[i], obs_arr_subs; operator = Symbolics. Operator)
599- end
600- for i in eachindex (obs)
601- obs[i] = fast_substitute (obs[i], obs_arr_subs; operator = Symbolics. Operator)
602- end
603- for i in eachindex (subeqs)
604- subeqs[i] = fast_substitute (subeqs[i], obs_arr_subs; operator = Symbolics. Operator)
635+ # `change_origin` is required because `Origin(firstind)(scal)` makes codegen
636+ # try to `create_array(OffsetArray{...}, ...)` which errors.
637+ # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
638+ # of `scal`.
639+ push! (obs_arr_eqs, arrvar ~ change_origin (Origin (firstind), scal))
605640 end
641+ append! (obs, obs_arr_eqs)
642+ append! (subeqs, obs_arr_eqs)
643+ # need to re-sort subeqs
644+ subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
606645
607646 @set! sys. eqs = neweqs
608647 @set! sys. observed = obs
609648
610- unknowns = Any[v
611- for (i, v) in enumerate (fullvars)
612- if diff_to_var[i] === nothing && ispresent (i)]
613- if ! isempty (extra_vars)
614- for v in extra_vars
615- push! (unknowns, old_fullvars[v])
616- end
617- end
618- @set! sys. unknowns = unknowns
619649 @set! sys. substitutions = Substitutions (subeqs, deps)
620650
621651 # Only makes sense for time-dependent
@@ -629,6 +659,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629659 return invalidate_cache! (sys)
630660end
631661
662+ function change_origin (origin, arr)
663+ return origin (arr)
664+ end
665+
666+ @register_array_symbolic change_origin (origin:: Origin , arr:: AbstractArray ) begin
667+ size = size (arr)
668+ eltype = eltype (arr)
669+ ndims = ndims (arr)
670+ end
671+
632672function tearing (state:: TearingState ; kwargs... )
633673 state. structure. solvable_graph === nothing && find_solvables! (state; kwargs... )
634674 complete! (state. structure)
0 commit comments