@@ -896,7 +896,7 @@ Update the system equations, unknowns, and observables after simplification.
896896"""
897897function update_simplified_system! (
898898 state:: TearingState , neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
899- cse_hack = true , array_hack = true )
899+ array_hack = true )
900900 @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state. structure
901901 diff_to_var = invview (var_to_diff)
902902
@@ -920,8 +920,7 @@ function update_simplified_system!(
920920 unknowns = [unknowns; extra_unknowns]
921921 @set! sys. unknowns = unknowns
922922
923- obs = cse_and_array_hacks (
924- sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
923+ obs = tearing_hacks (sys, obs, unknowns, neweqs; array = array_hack)
925924
926925 @set! sys. eqs = neweqs
927926 @set! sys. observed = obs
@@ -977,7 +976,7 @@ differential variables.
977976 according to `full_var_eq_matching`.
978977"""
979978function tearing_reassemble (state:: TearingState , var_eq_matching:: Matching ,
980- full_var_eq_matching:: Matching , var_sccs:: Vector{Vector{Int}} ; simplify = false , mm, cse_hack = true ,
979+ full_var_eq_matching:: Matching , var_sccs:: Vector{Vector{Int}} ; simplify = false , mm,
981980 array_hack = true , fully_determined = true )
982981 extra_eqs_vars = get_extra_eqs_vars (state, full_var_eq_matching, fully_determined)
983982 neweqs = collect (equations (state))
@@ -1010,7 +1009,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
10101009 # var_eq_matching and full_var_eq_matching are now invalidated
10111010
10121011 sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_sccs,
1013- extra_unknowns; cse_hack, array_hack)
1012+ extra_unknowns; array_hack)
10141013
10151014 @set! state. sys = sys
10161015 @set! sys. tearing_state = state
@@ -1047,60 +1046,22 @@ function get_extra_eqs_vars(
10471046end
10481047
10491048"""
1050- # HACK 1
1051-
1052- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
1053- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
1054- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
1055- avoid the unnecessary cost. This and the below hack are implemented simultaneously
1056-
1057- # HACK 2
1049+ # HACK
10581050
10591051Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
10601052equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
10611053if all `p[i]` are present and the unscalarized form is used in any equation (observed or
10621054not) we first count the number of times the scalarized form of each observed variable
10631055occurs in observed equations (and unknowns if it's split).
10641056"""
1065- function cse_and_array_hacks (sys, obs, unknowns, neweqs; cse = true , array = true )
1066- # HACK 1
1067- # mapping of rhs to temporary CSE variable
1068- # `f(...) => tmpvar` in above example
1069- rhs_to_tempvar = Dict ()
1070-
1071- # HACK 2
1057+ function tearing_hacks (sys, obs, unknowns, neweqs; array = true )
10721058 # map of array observed variable (unscalarized) to number of its
10731059 # scalarized terms that appear in observed equations
10741060 arr_obs_occurrences = Dict ()
10751061 for (i, eq) in enumerate (obs)
10761062 lhs = eq. lhs
10771063 rhs = eq. rhs
10781064
1079- # HACK 1
1080- if cse && is_getindexed_array (rhs)
1081- rhs_arr = arguments (rhs)[1 ]
1082- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
1083- if ! haskey (rhs_to_tempvar, rhs_arr)
1084- tempvar = gensym (Symbol (lhs))
1085- N = length (rhs_arr)
1086- tempvar = unwrap (Symbolics. variable (
1087- tempvar; T = Symbolics. symtype (rhs_arr)))
1088- tempvar = setmetadata (
1089- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
1090- tempeq = tempvar ~ rhs_arr
1091- rhs_to_tempvar[rhs_arr] = tempvar
1092- push! (obs, tempeq)
1093- end
1094-
1095- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
1096- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
1097- # which fails the topological sort
1098- neweq = lhs ~ getindex_wrapper (
1099- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
1100- obs[i] = neweq
1101- end
1102- # end HACK 1
1103-
11041065 array || continue
11051066 iscall (lhs) || continue
11061067 operation (lhs) === getindex || continue
@@ -1111,31 +1072,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
11111072 continue
11121073 end
11131074
1114- # Also do CSE for `equations(sys)`
1115- if cse
1116- for (i, eq) in enumerate (neweqs)
1117- (; lhs, rhs) = eq
1118- is_getindexed_array (rhs) || continue
1119- rhs_arr = arguments (rhs)[1 ]
1120- if ! haskey (rhs_to_tempvar, rhs_arr)
1121- tempvar = gensym (Symbol (lhs))
1122- N = length (rhs_arr)
1123- tempvar = unwrap (Symbolics. variable (
1124- tempvar; T = Symbolics. symtype (rhs_arr)))
1125- tempvar = setmetadata (
1126- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
1127- tempeq = tempvar ~ rhs_arr
1128- rhs_to_tempvar[rhs_arr] = tempvar
1129- push! (obs, tempeq)
1130- end
1131- # don't need getindex_wrapper, but do it anyway to know that this
1132- # hack took place
1133- neweq = lhs ~ getindex_wrapper (
1134- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
1135- neweqs[i] = neweq
1136- end
1137- end
1138-
11391075 # count variables in unknowns if they are scalarized forms of variables
11401076 # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
11411077 # is an observed equation.
@@ -1170,18 +1106,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
11701106 return obs
11711107end
11721108
1173- function is_getindexed_array (rhs)
1174- (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
1175- iscall (rhs) && operation (rhs) === getindex &&
1176- Symbolics. shape (rhs) != Symbolics. Unknown ()
1177- end
1178-
1179- # PART OF HACK 1
1180- getindex_wrapper (x, i) = x[i... ]
1181-
1182- @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
1183-
1184- # PART OF HACK 2
1109+ # PART OF HACK
11851110function change_origin (origin, arr)
11861111 if all (isone, Tuple (origin))
11871112 return arr
@@ -1209,11 +1134,11 @@ new residual equations after tearing. End users are encouraged to call [`mtkcomp
12091134instead, which calls this function internally.
12101135"""
12111136function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1212- simplify = false , cse_hack = true , array_hack = true , fully_determined = true , kwargs... )
1137+ simplify = false , array_hack = true , fully_determined = true , kwargs... )
12131138 var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing (state)
12141139 invalidate_cache! (tearing_reassemble (
12151140 state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
1216- simplify, cse_hack, array_hack, fully_determined))
1141+ simplify, array_hack, fully_determined))
12171142end
12181143
12191144"""
@@ -1223,7 +1148,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
12231148the system is balanced.
12241149"""
12251150function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1226- mm = nothing , cse_hack = true , array_hack = true , fully_determined = true , kwargs... )
1151+ mm = nothing , array_hack = true , fully_determined = true , kwargs... )
12271152 jac = let state = state
12281153 (eqs, vars) -> begin
12291154 symeqs = EquationsView (state)[eqs]
@@ -1249,5 +1174,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
12491174 state, jac; state_priority,
12501175 kwargs... )
12511176 tearing_reassemble (state, var_eq_matching, full_var_eq_matching, var_sccs;
1252- simplify, mm, cse_hack, array_hack, fully_determined)
1177+ simplify, mm, array_hack, fully_determined)
12531178end
0 commit comments