230
230
=#
231
231
232
232
function tearing_reassemble (state:: TearingState , var_eq_matching,
233
- full_var_eq_matching = nothing ; simplify = false , mm = nothing )
233
+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
234
234
@unpack fullvars, sys, structure = state
235
235
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
236
236
extra_vars = Int[]
@@ -574,39 +574,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574
574
# TODO : compute the dependency correctly so that we don't have to do this
575
575
obs = [fast_substitute (observed (sys), obs_sub); subeqs]
576
576
577
- # HACK: Substitute non-scalarized symbolic arrays of observed variables
578
- # E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579
- # ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580
- # by the topological sorting and dependency identification pieces
581
- obs_arr_subs = Dict ()
582
-
583
- for eq in obs
584
- lhs = eq. lhs
585
- iscall (lhs) || continue
586
- operation (lhs) === getindex || continue
587
- Symbolics. shape (lhs) != = Symbolics. Unknown () || continue
588
- arg1 = arguments (lhs)[1 ]
589
- haskey (obs_arr_subs, arg1) && continue
590
- obs_arr_subs[arg1] = [arg1[i] for i in eachindex (arg1)] # e.g. p => [p[1], p[2]]
591
- index_first = eachindex (arg1)[1 ]
592
-
593
- # respect non-1-indexed arrays
594
- # TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
595
- obs_arr_subs[arg1] = Origin (index_first)(obs_arr_subs[arg1])
596
- end
597
- for i in eachindex (neweqs)
598
- neweqs[i] = fast_substitute (neweqs[i], obs_arr_subs; operator = Symbolics. Operator)
599
- end
600
- for i in eachindex (obs)
601
- obs[i] = fast_substitute (obs[i], obs_arr_subs; operator = Symbolics. Operator)
602
- end
603
- for i in eachindex (subeqs)
604
- subeqs[i] = fast_substitute (subeqs[i], obs_arr_subs; operator = Symbolics. Operator)
605
- end
606
-
607
- @set! sys. eqs = neweqs
608
- @set! sys. observed = obs
609
-
610
577
unknowns = Any[v
611
578
for (i, v) in enumerate (fullvars)
612
579
if diff_to_var[i] === nothing && ispresent (i)]
@@ -616,6 +583,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
616
583
end
617
584
end
618
585
@set! sys. unknowns = unknowns
586
+
587
+ obs, subeqs, deps = cse_and_array_hacks (
588
+ obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
589
+
590
+ @set! sys. eqs = neweqs
591
+ @set! sys. observed = obs
592
+
619
593
@set! sys. substitutions = Substitutions (subeqs, deps)
620
594
621
595
# Only makes sense for time-dependent
@@ -629,6 +603,168 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629
603
return invalidate_cache! (sys)
630
604
end
631
605
606
+ """
607
+ # HACK 1
608
+
609
+ Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
610
+ gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
611
+ _very_ expensive. this hack performs a limited form of CSE specifically for this case to
612
+ avoid the unnecessary cost. This and the below hack are implemented simultaneously
613
+
614
+ # HACK 2
615
+
616
+ Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
617
+ equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
618
+ if all `p[i]` are present and the unscalarized form is used in any equation (observed or
619
+ not) we first count the number of times the scalarized form of each observed variable
620
+ occurs in observed equations (and unknowns if it's split).
621
+ """
622
+ function cse_and_array_hacks (obs, subeqs, unknowns, neweqs; cse = true , array = true )
623
+ # HACK 1
624
+ # mapping of rhs to temporary CSE variable
625
+ # `f(...) => tmpvar` in above example
626
+ rhs_to_tempvar = Dict ()
627
+
628
+ # HACK 2
629
+ # map of array observed variable (unscalarized) to number of its
630
+ # scalarized terms that appear in observed equations
631
+ arr_obs_occurrences = Dict ()
632
+ # to check if array variables occur in unscalarized form anywhere
633
+ all_vars = Set ()
634
+ for (i, eq) in enumerate (obs)
635
+ lhs = eq. lhs
636
+ rhs = eq. rhs
637
+ vars! (all_vars, rhs)
638
+
639
+ # HACK 1
640
+ if cse && is_getindexed_array (rhs)
641
+ rhs_arr = arguments (rhs)[1 ]
642
+ if ! haskey (rhs_to_tempvar, rhs_arr)
643
+ tempvar = gensym (Symbol (lhs))
644
+ N = length (rhs_arr)
645
+ tempvar = unwrap (Symbolics. variable (
646
+ tempvar; T = Symbolics. symtype (rhs_arr)))
647
+ tempvar = setmetadata (
648
+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
649
+ tempeq = tempvar ~ rhs_arr
650
+ rhs_to_tempvar[rhs_arr] = tempvar
651
+ push! (obs, tempeq)
652
+ push! (subeqs, tempeq)
653
+ end
654
+
655
+ # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
656
+ # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
657
+ # which fails the topological sort
658
+ neweq = lhs ~ getindex_wrapper (
659
+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
660
+ obs[i] = neweq
661
+ subeqi = findfirst (isequal (eq), subeqs)
662
+ if subeqi != = nothing
663
+ subeqs[subeqi] = neweq
664
+ end
665
+ end
666
+ # end HACK 1
667
+
668
+ array || continue
669
+ iscall (lhs) || continue
670
+ operation (lhs) === getindex || continue
671
+ Symbolics. shape (lhs) != Symbolics. Unknown () || continue
672
+ arg1 = arguments (lhs)[1 ]
673
+ cnt = get (arr_obs_occurrences, arg1, 0 )
674
+ arr_obs_occurrences[arg1] = cnt + 1
675
+ continue
676
+ end
677
+
678
+ # Also do CSE for `equations(sys)`
679
+ if cse
680
+ for (i, eq) in enumerate (neweqs)
681
+ (; lhs, rhs) = eq
682
+ is_getindexed_array (rhs) || continue
683
+ rhs_arr = arguments (rhs)[1 ]
684
+ if ! haskey (rhs_to_tempvar, rhs_arr)
685
+ tempvar = gensym (Symbol (lhs))
686
+ N = length (rhs_arr)
687
+ tempvar = unwrap (Symbolics. variable (
688
+ tempvar; T = Symbolics. symtype (rhs_arr)))
689
+ tempvar = setmetadata (
690
+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
691
+ tempeq = tempvar ~ rhs_arr
692
+ rhs_to_tempvar[rhs_arr] = tempvar
693
+ push! (obs, tempeq)
694
+ push! (subeqs, tempeq)
695
+ end
696
+ # don't need getindex_wrapper, but do it anyway to know that this
697
+ # hack took place
698
+ neweq = lhs ~ getindex_wrapper (
699
+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
700
+ neweqs[i] = neweq
701
+ end
702
+ end
703
+
704
+ # count variables in unknowns if they are scalarized forms of variables
705
+ # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
706
+ # is an observed equation.
707
+ for sym in unknowns
708
+ iscall (sym) || continue
709
+ operation (sym) === getindex || continue
710
+ Symbolics. shape (sym) != Symbolics. Unknown () || continue
711
+ arg1 = arguments (sym)[1 ]
712
+ cnt = get (arr_obs_occurrences, arg1, 0 )
713
+ cnt == 0 && continue
714
+ arr_obs_occurrences[arg1] = cnt + 1
715
+ end
716
+ for eq in neweqs
717
+ vars! (all_vars, eq. rhs)
718
+ end
719
+ obs_arr_eqs = Equation[]
720
+ for (arrvar, cnt) in arr_obs_occurrences
721
+ cnt == length (arrvar) || continue
722
+ arrvar in all_vars || continue
723
+ # firstindex returns 1 for multidimensional array symbolics
724
+ firstind = first (eachindex (arrvar))
725
+ scal = [arrvar[i] for i in eachindex (arrvar)]
726
+ # respect non-1-indexed arrays
727
+ # TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
728
+ # `change_origin` is required because `Origin(firstind)(scal)` makes codegen
729
+ # try to `create_array(OffsetArray{...}, ...)` which errors.
730
+ # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
731
+ # of `scal`.
732
+ push! (obs_arr_eqs, arrvar ~ change_origin (Origin (firstind), scal))
733
+ end
734
+ append! (obs, obs_arr_eqs)
735
+ append! (subeqs, obs_arr_eqs)
736
+
737
+ # need to re-sort subeqs
738
+ subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
739
+
740
+ deps = Vector{Int}[i == 1 ? Int[] : collect (1 : (i - 1 ))
741
+ for i in 1 : length (subeqs)]
742
+
743
+ return obs, subeqs, deps
744
+ end
745
+
746
+ function is_getindexed_array (rhs)
747
+ (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
748
+ iscall (rhs) && operation (rhs) === getindex &&
749
+ Symbolics. shape (rhs) != Symbolics. Unknown ()
750
+ end
751
+
752
+ # PART OF HACK 1
753
+ getindex_wrapper (x, i) = x[i... ]
754
+
755
+ @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
756
+
757
+ # PART OF HACK 2
758
+ function change_origin (origin, arr)
759
+ return origin (arr)
760
+ end
761
+
762
+ @register_array_symbolic change_origin (origin:: Origin , arr:: AbstractArray ) begin
763
+ size = size (arr)
764
+ eltype = eltype (arr)
765
+ ndims = ndims (arr)
766
+ end
767
+
632
768
function tearing (state:: TearingState ; kwargs... )
633
769
state. structure. solvable_graph === nothing && find_solvables! (state; kwargs... )
634
770
complete! (state. structure)
@@ -643,10 +779,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
643
779
instead, which calls this function internally.
644
780
"""
645
781
function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
646
- simplify = false , kwargs... )
782
+ simplify = false , cse_hack = true , array_hack = true , kwargs... )
647
783
var_eq_matching, full_var_eq_matching = tearing (state)
648
784
invalidate_cache! (tearing_reassemble (
649
- state, var_eq_matching, full_var_eq_matching; mm, simplify))
785
+ state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack ))
650
786
end
651
787
652
788
"""
@@ -668,7 +804,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
668
804
the system is balanced.
669
805
"""
670
806
function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
671
- mm = nothing , kwargs... )
807
+ mm = nothing , cse_hack = true , array_hack = true , kwargs... )
672
808
jac = let state = state
673
809
(eqs, vars) -> begin
674
810
symeqs = EquationsView (state)[eqs]
@@ -692,5 +828,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
692
828
end
693
829
var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
694
830
kwargs... )
695
- tearing_reassemble (state, var_eq_matching; simplify, mm)
831
+ tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack )
696
832
end
0 commit comments