@@ -929,7 +929,7 @@ Update the system equations, unknowns, and observables after simplification.
929929"""
930930function update_simplified_system! (
931931 state:: TearingState , neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
932- cse_hack = true , array_hack = true , D = nothing , iv = nothing )
932+ array_hack = true , D = nothing , iv = nothing )
933933 @unpack fullvars, structure = state
934934 @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
935935 diff_to_var = invview (var_to_diff)
@@ -978,8 +978,7 @@ function update_simplified_system!(
978978 end
979979 @set! sys. unknowns = unknowns
980980
981- obs = cse_and_array_hacks (
982- sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
981+ obs = tearing_hacks (sys, obs, unknowns, neweqs; array = array_hack)
983982
984983 @set! sys. eqs = neweqs
985984 @set! sys. observed = obs
@@ -1035,7 +1034,7 @@ differential variables.
10351034 according to `full_var_eq_matching`.
10361035"""
10371036function tearing_reassemble (state:: TearingState , var_eq_matching:: Matching ,
1038- full_var_eq_matching:: Matching , var_sccs:: Vector{Vector{Int}} ; simplify = false , mm, cse_hack = true ,
1037+ full_var_eq_matching:: Matching , var_sccs:: Vector{Vector{Int}} ; simplify = false , mm,
10391038 array_hack = true , fully_determined = true )
10401039 extra_eqs_vars = get_extra_eqs_vars (
10411040 state, var_eq_matching, full_var_eq_matching, fully_determined)
@@ -1074,7 +1073,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
10741073 # var_eq_matching and full_var_eq_matching are now invalidated
10751074
10761075 sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_sccs,
1077- extra_unknowns; cse_hack, array_hack, iv, D)
1076+ extra_unknowns; array_hack, iv, D)
10781077
10791078 @set! state. sys = sys
10801079 @set! sys. tearing_state = state
@@ -1223,60 +1222,22 @@ function get_extra_eqs_vars(
12231222end
12241223
12251224"""
1226- # HACK 1
1227-
1228- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
1229- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
1230- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
1231- avoid the unnecessary cost. This and the below hack are implemented simultaneously
1232-
1233- # HACK 2
1225+ # HACK
12341226
12351227Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
12361228equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
12371229if all `p[i]` are present and the unscalarized form is used in any equation (observed or
12381230not) we first count the number of times the scalarized form of each observed variable
12391231occurs in observed equations (and unknowns if it's split).
12401232"""
1241- function cse_and_array_hacks (sys, obs, unknowns, neweqs; cse = true , array = true )
1242- # HACK 1
1243- # mapping of rhs to temporary CSE variable
1244- # `f(...) => tmpvar` in above example
1245- rhs_to_tempvar = Dict ()
1246-
1247- # HACK 2
1233+ function tearing_hacks (sys, obs, unknowns, neweqs; array = true )
12481234 # map of array observed variable (unscalarized) to number of its
12491235 # scalarized terms that appear in observed equations
12501236 arr_obs_occurrences = Dict ()
12511237 for (i, eq) in enumerate (obs)
12521238 lhs = eq. lhs
12531239 rhs = eq. rhs
12541240
1255- # HACK 1
1256- if cse && is_getindexed_array (rhs)
1257- rhs_arr = arguments (rhs)[1 ]
1258- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
1259- if ! haskey (rhs_to_tempvar, rhs_arr)
1260- tempvar = gensym (Symbol (lhs))
1261- N = length (rhs_arr)
1262- tempvar = unwrap (Symbolics. variable (
1263- tempvar; T = Symbolics. symtype (rhs_arr)))
1264- tempvar = setmetadata (
1265- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
1266- tempeq = tempvar ~ rhs_arr
1267- rhs_to_tempvar[rhs_arr] = tempvar
1268- push! (obs, tempeq)
1269- end
1270-
1271- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
1272- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
1273- # which fails the topological sort
1274- neweq = lhs ~ getindex_wrapper (
1275- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
1276- obs[i] = neweq
1277- end
1278- # end HACK 1
1279-
12801241 array || continue
12811242 iscall (lhs) || continue
12821243 operation (lhs) === getindex || continue
@@ -1287,31 +1248,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
12871248 continue
12881249 end
12891250
1290- # Also do CSE for `equations(sys)`
1291- if cse
1292- for (i, eq) in enumerate (neweqs)
1293- (; lhs, rhs) = eq
1294- is_getindexed_array (rhs) || continue
1295- rhs_arr = arguments (rhs)[1 ]
1296- if ! haskey (rhs_to_tempvar, rhs_arr)
1297- tempvar = gensym (Symbol (lhs))
1298- N = length (rhs_arr)
1299- tempvar = unwrap (Symbolics. variable (
1300- tempvar; T = Symbolics. symtype (rhs_arr)))
1301- tempvar = setmetadata (
1302- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
1303- tempeq = tempvar ~ rhs_arr
1304- rhs_to_tempvar[rhs_arr] = tempvar
1305- push! (obs, tempeq)
1306- end
1307- # don't need getindex_wrapper, but do it anyway to know that this
1308- # hack took place
1309- neweq = lhs ~ getindex_wrapper (
1310- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
1311- neweqs[i] = neweq
1312- end
1313- end
1314-
13151251 # count variables in unknowns if they are scalarized forms of variables
13161252 # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
13171253 # is an observed equation.
@@ -1346,18 +1282,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
13461282 return obs
13471283end
13481284
1349- function is_getindexed_array (rhs)
1350- (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
1351- iscall (rhs) && operation (rhs) === getindex &&
1352- Symbolics. shape (rhs) != Symbolics. Unknown ()
1353- end
1354-
1355- # PART OF HACK 1
1356- getindex_wrapper (x, i) = x[i... ]
1357-
1358- @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
1359-
1360- # PART OF HACK 2
1285+ # PART OF HACK
13611286function change_origin (origin, arr)
13621287 if all (isone, Tuple (origin))
13631288 return arr
@@ -1385,11 +1310,11 @@ new residual equations after tearing. End users are encouraged to call [`mtkcomp
13851310instead, which calls this function internally.
13861311"""
13871312function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1388- simplify = false , cse_hack = true , array_hack = true , fully_determined = true , kwargs... )
1313+ simplify = false , array_hack = true , fully_determined = true , kwargs... )
13891314 var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing (state)
13901315 invalidate_cache! (tearing_reassemble (
13911316 state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
1392- simplify, cse_hack, array_hack, fully_determined))
1317+ simplify, array_hack, fully_determined))
13931318end
13941319
13951320"""
@@ -1399,7 +1324,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
13991324the system is balanced.
14001325"""
14011326function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1402- mm = nothing , cse_hack = true , array_hack = true , fully_determined = true , kwargs... )
1327+ mm = nothing , array_hack = true , fully_determined = true , kwargs... )
14031328 jac = let state = state
14041329 (eqs, vars) -> begin
14051330 symeqs = EquationsView (state)[eqs]
@@ -1425,5 +1350,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
14251350 state, jac; state_priority,
14261351 kwargs... )
14271352 tearing_reassemble (state, var_eq_matching, full_var_eq_matching, var_sccs;
1428- simplify, mm, cse_hack, array_hack, fully_determined)
1353+ simplify, mm, array_hack, fully_determined)
14291354end
0 commit comments