@@ -574,48 +574,78 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574
574
# TODO : compute the dependency correctly so that we don't have to do this
575
575
obs = [fast_substitute (observed (sys), obs_sub); subeqs]
576
576
577
- # HACK: Substitute non-scalarized symbolic arrays of observed variables
578
- # E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579
- # ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580
- # by the topological sorting and dependency identification pieces
581
- obs_arr_subs = Dict ()
577
+ unknowns = Any[v
578
+ for (i, v) in enumerate (fullvars)
579
+ if diff_to_var[i] === nothing && ispresent (i)]
580
+ if ! isempty (extra_vars)
581
+ for v in extra_vars
582
+ push! (unknowns, old_fullvars[v])
583
+ end
584
+ end
585
+ @set! sys. unknowns = unknowns
582
586
587
+ # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588
+ # are equations, add an equation `p ~ [p[1], p[2], ...]`
589
+ # allow topsort to reorder them
590
+ # only add the new equation if all `p[i]` are present and the unscalarized
591
+ # form is used in any equation (observed or not)
592
+ # we first count the number of times the scalarized form of each observed
593
+ # variable occurs in observed equations (and unknowns if it's split).
594
+
595
+ # map of array observed variable (unscalarized) to number of its
596
+ # scalarized terms that appear in observed equations
597
+ arr_obs_occurrences = Dict ()
598
+ # to check if array variables occur in unscalarized form anywhere
599
+ all_vars = Set ()
583
600
for eq in obs
601
+ vars! (all_vars, eq. rhs)
584
602
lhs = eq. lhs
585
603
iscall (lhs) || continue
586
604
operation (lhs) === getindex || continue
587
- Symbolics. shape (lhs) != = Symbolics. Unknown () || continue
605
+ Symbolics. shape (lhs) != Symbolics. Unknown () || continue
588
606
arg1 = arguments (lhs)[1 ]
589
- haskey (obs_arr_subs, arg1) && continue
590
- obs_arr_subs[arg1] = [arg1[i] for i in eachindex (arg1)] # e.g. p => [p[1], p[2]]
591
- index_first = eachindex (arg1)[1 ]
592
-
607
+ cnt = get (arr_obs_occurrences, arg1, 0 )
608
+ arr_obs_occurrences[arg1] = cnt + 1
609
+ continue
610
+ end
611
+ # count variables in unknowns if they are scalarized forms of variables
612
+ # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
613
+ # is an observed equation.
614
+ for sym in unknowns
615
+ iscall (sym) || continue
616
+ operation (sym) === getindex || continue
617
+ Symbolics. shape (sym) != Symbolics. Unknown () || continue
618
+ arg1 = arguments (sym)[1 ]
619
+ cnt = get (arr_obs_occurrences, arg1, 0 )
620
+ cnt == 0 && continue
621
+ arr_obs_occurrences[arg1] = cnt + 1
622
+ end
623
+ for eq in neweqs
624
+ vars! (all_vars, eq. rhs)
625
+ end
626
+ obs_arr_eqs = Equation[]
627
+ for (arrvar, cnt) in arr_obs_occurrences
628
+ cnt == length (arrvar) || continue
629
+ arrvar in all_vars || continue
630
+ # firstindex returns 1 for multidimensional array symbolics
631
+ firstind = first (eachindex (arrvar))
632
+ scal = [arrvar[i] for i in eachindex (arrvar)]
593
633
# respect non-1-indexed arrays
594
634
# TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
595
- obs_arr_subs[arg1] = Origin (index_first)(obs_arr_subs[arg1])
596
- end
597
- for i in eachindex (neweqs)
598
- neweqs[i] = fast_substitute (neweqs[i], obs_arr_subs; operator = Symbolics. Operator)
599
- end
600
- for i in eachindex (obs)
601
- obs[i] = fast_substitute (obs[i], obs_arr_subs; operator = Symbolics. Operator)
602
- end
603
- for i in eachindex (subeqs)
604
- subeqs[i] = fast_substitute (subeqs[i], obs_arr_subs; operator = Symbolics. Operator)
635
+ # `change_origin` is required because `Origin(firstind)(scal)` makes codegen
636
+ # try to `create_array(OffsetArray{...}, ...)` which errors.
637
+ # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
638
+ # of `scal`.
639
+ push! (obs_arr_eqs, arrvar ~ change_origin (Origin (firstind), scal))
605
640
end
641
+ append! (obs, obs_arr_eqs)
642
+ append! (subeqs, obs_arr_eqs)
643
+ # need to re-sort subeqs
644
+ subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
606
645
607
646
@set! sys. eqs = neweqs
608
647
@set! sys. observed = obs
609
648
610
- unknowns = Any[v
611
- for (i, v) in enumerate (fullvars)
612
- if diff_to_var[i] === nothing && ispresent (i)]
613
- if ! isempty (extra_vars)
614
- for v in extra_vars
615
- push! (unknowns, old_fullvars[v])
616
- end
617
- end
618
- @set! sys. unknowns = unknowns
619
649
@set! sys. substitutions = Substitutions (subeqs, deps)
620
650
621
651
# Only makes sense for time-dependent
@@ -629,6 +659,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629
659
return invalidate_cache! (sys)
630
660
end
631
661
662
+ function change_origin (origin, arr)
663
+ return origin (arr)
664
+ end
665
+
666
+ @register_array_symbolic change_origin (origin:: Origin , arr:: AbstractArray ) begin
667
+ size = size (arr)
668
+ eltype = eltype (arr)
669
+ ndims = ndims (arr)
670
+ end
671
+
632
672
function tearing (state:: TearingState ; kwargs... )
633
673
state. structure. solvable_graph === nothing && find_solvables! (state; kwargs... )
634
674
complete! (state. structure)
0 commit comments