@@ -724,7 +724,7 @@ Update the system equations, unknowns, and observables after simplification.
724724"""
725725function update_simplified_system! (
726726 state:: TearingState , neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
727- cse_hack = true , array_hack = true )
727+ array_hack = true )
728728 @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state. structure
729729 diff_to_var = invview (var_to_diff)
730730
@@ -748,8 +748,7 @@ function update_simplified_system!(
748748 unknowns = [unknowns; extra_unknowns]
749749 @set! sys. unknowns = unknowns
750750
751- obs = cse_and_array_hacks (
752- sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
751+ obs = tearing_hacks (sys, obs, unknowns, neweqs; array = array_hack)
753752
754753 deps = Vector{Int}[i == 1 ? Int[] : collect (1 : (i - 1 ))
755754 for i in 1 : length (solved_eqs)]
@@ -793,7 +792,7 @@ appear in the system. Algebraic variables are variables that are not
793792differential variables.
794793"""
795794function tearing_reassemble (state:: TearingState , var_eq_matching,
796- full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
795+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , array_hack = true )
797796 extra_vars = Int[]
798797 if full_var_eq_matching != = nothing
799798 for v in 𝑑vertices (state. structure. graph)
@@ -829,68 +828,30 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
829828 state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
830829
831830 sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
832- extra_unknowns; cse_hack, array_hack)
831+ extra_unknowns; array_hack)
833832
834833 @set! state. sys = sys
835834 @set! sys. tearing_state = state
836835 return invalidate_cache! (sys)
837836end
838837
839838"""
840- # HACK 1
841-
842- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
843- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
844- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
845- avoid the unnecessary cost. This and the below hack are implemented simultaneously
846-
847- # HACK 2
839+ # HACK
848840
849841Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
850842equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
851843if all `p[i]` are present and the unscalarized form is used in any equation (observed or
852844not) we first count the number of times the scalarized form of each observed variable
853845occurs in observed equations (and unknowns if it's split).
854846"""
855- function cse_and_array_hacks (sys, obs, unknowns, neweqs; cse = true , array = true )
856- # HACK 1
857- # mapping of rhs to temporary CSE variable
858- # `f(...) => tmpvar` in above example
859- rhs_to_tempvar = Dict ()
860-
861- # HACK 2
847+ function tearing_hacks (sys, obs, unknowns, neweqs; array = true )
862848 # map of array observed variable (unscalarized) to number of its
863849 # scalarized terms that appear in observed equations
864850 arr_obs_occurrences = Dict ()
865851 for (i, eq) in enumerate (obs)
866852 lhs = eq. lhs
867853 rhs = eq. rhs
868854
869- # HACK 1
870- if cse && is_getindexed_array (rhs)
871- rhs_arr = arguments (rhs)[1 ]
872- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
873- if ! haskey (rhs_to_tempvar, rhs_arr)
874- tempvar = gensym (Symbol (lhs))
875- N = length (rhs_arr)
876- tempvar = unwrap (Symbolics. variable (
877- tempvar; T = Symbolics. symtype (rhs_arr)))
878- tempvar = setmetadata (
879- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
880- tempeq = tempvar ~ rhs_arr
881- rhs_to_tempvar[rhs_arr] = tempvar
882- push! (obs, tempeq)
883- end
884-
885- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
886- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
887- # which fails the topological sort
888- neweq = lhs ~ getindex_wrapper (
889- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
890- obs[i] = neweq
891- end
892- # end HACK 1
893-
894855 array || continue
895856 iscall (lhs) || continue
896857 operation (lhs) === getindex || continue
@@ -901,31 +862,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
901862 continue
902863 end
903864
904- # Also do CSE for `equations(sys)`
905- if cse
906- for (i, eq) in enumerate (neweqs)
907- (; lhs, rhs) = eq
908- is_getindexed_array (rhs) || continue
909- rhs_arr = arguments (rhs)[1 ]
910- if ! haskey (rhs_to_tempvar, rhs_arr)
911- tempvar = gensym (Symbol (lhs))
912- N = length (rhs_arr)
913- tempvar = unwrap (Symbolics. variable (
914- tempvar; T = Symbolics. symtype (rhs_arr)))
915- tempvar = setmetadata (
916- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
917- tempeq = tempvar ~ rhs_arr
918- rhs_to_tempvar[rhs_arr] = tempvar
919- push! (obs, tempeq)
920- end
921- # don't need getindex_wrapper, but do it anyway to know that this
922- # hack took place
923- neweq = lhs ~ getindex_wrapper (
924- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
925- neweqs[i] = neweq
926- end
927- end
928-
929865 # count variables in unknowns if they are scalarized forms of variables
930866 # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
931867 # is an observed equation.
@@ -960,18 +896,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
960896 return obs
961897end
962898
963- function is_getindexed_array (rhs)
964- (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
965- iscall (rhs) && operation (rhs) === getindex &&
966- Symbolics. shape (rhs) != Symbolics. Unknown ()
967- end
968-
969- # PART OF HACK 1
970- getindex_wrapper (x, i) = x[i... ]
971-
972- @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
973-
974- # PART OF HACK 2
899+ # PART OF HACK
975900function change_origin (origin, arr)
976901 if all (isone, Tuple (origin))
977902 return arr
@@ -999,10 +924,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
999924instead, which calls this function internally.
1000925"""
1001926function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1002- simplify = false , cse_hack = true , array_hack = true , kwargs... )
927+ simplify = false , array_hack = true , kwargs... )
1003928 var_eq_matching, full_var_eq_matching = tearing (state)
1004929 invalidate_cache! (tearing_reassemble (
1005- state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
930+ state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
1006931end
1007932
1008933"""
@@ -1024,7 +949,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1024949the system is balanced.
1025950"""
1026951function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1027- mm = nothing , cse_hack = true , array_hack = true , kwargs... )
952+ mm = nothing , array_hack = true , kwargs... )
1028953 jac = let state = state
1029954 (eqs, vars) -> begin
1030955 symeqs = EquationsView (state)[eqs]
@@ -1048,5 +973,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1048973 end
1049974 var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
1050975 kwargs... )
1051- tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack)
976+ tearing_reassemble (state, var_eq_matching; simplify, mm, array_hack)
1052977end
0 commit comments