@@ -708,7 +708,7 @@ Update the system equations, unknowns, and observables after simplification.
708708"""
709709function update_simplified_system! (
710710 state:: TearingState , neweqs, solved_eqs, dummy_sub, var_eq_matching, extra_unknowns;
711- cse_hack = true , array_hack = true )
711+ array_hack = true )
712712 @unpack solvable_graph, var_to_diff, eq_to_diff, graph = state. structure
713713 diff_to_var = invview (var_to_diff)
714714
@@ -732,8 +732,8 @@ function update_simplified_system!(
732732 unknowns = [unknowns; extra_unknowns]
733733 @set! sys. unknowns = unknowns
734734
735- obs, subeqs, deps = cse_and_array_hacks (
736- sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
735+ obs, subeqs, deps = array_var_hack (
736+ sys, obs, solved_eqs, unknowns, neweqs; array = array_hack)
737737
738738 @set! sys. eqs = neweqs
739739 @set! sys. observed = obs
@@ -775,7 +775,7 @@ appear in the system. Algebraic variables are variables that are not
775775differential variables.
776776"""
777777function tearing_reassemble (state:: TearingState , var_eq_matching,
778- full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
778+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , array_hack = true )
779779 extra_vars = Int[]
780780 if full_var_eq_matching != = nothing
781781 for v in 𝑑vertices (state. structure. graph)
@@ -811,21 +811,14 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
811811 state, var_eq_matching, eq_ordering, var_ordering, nelim_eq, nelim_var)
812812
813813 sys = update_simplified_system! (state, neweqs, solved_eqs, dummy_sub, var_eq_matching,
814- extra_unknowns; cse_hack, array_hack)
814+ extra_unknowns; array_hack)
815815
816816 @set! state. sys = sys
817817 @set! sys. tearing_state = state
818818 return invalidate_cache! (sys)
819819end
820820
821821"""
822- # HACK 1
823-
824- Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
825- gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
826- _very_ expensive. this hack performs a limited form of CSE specifically for this case to
827- avoid the unnecessary cost. This and the below hack are implemented simultaneously
828-
829822# HACK 2
830823
831824Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
@@ -834,12 +827,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
834827not) we first count the number of times the scalarized form of each observed variable
835828occurs in observed equations (and unknowns if it's split).
836829"""
837- function cse_and_array_hacks (sys, obs, subeqs, unknowns, neweqs; cse = true , array = true )
838- # HACK 1
839- # mapping of rhs to temporary CSE variable
840- # `f(...) => tmpvar` in above example
841- rhs_to_tempvar = Dict ()
842-
830+ function array_var_hack (sys, obs, subeqs, unknowns, neweqs; array = true )
843831 # HACK 2
844832 # map of array observed variable (unscalarized) to number of its
845833 # scalarized terms that appear in observed equations
@@ -851,36 +839,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
851839 rhs = eq. rhs
852840 vars! (all_vars, rhs)
853841
854- # HACK 1
855- if cse && is_getindexed_array (rhs)
856- rhs_arr = arguments (rhs)[1 ]
857- iscall (rhs_arr) && operation (rhs_arr) isa Symbolics. Operator && continue
858- if ! haskey (rhs_to_tempvar, rhs_arr)
859- tempvar = gensym (Symbol (lhs))
860- N = length (rhs_arr)
861- tempvar = unwrap (Symbolics. variable (
862- tempvar; T = Symbolics. symtype (rhs_arr)))
863- tempvar = setmetadata (
864- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
865- tempeq = tempvar ~ rhs_arr
866- rhs_to_tempvar[rhs_arr] = tempvar
867- push! (obs, tempeq)
868- push! (subeqs, tempeq)
869- end
870-
871- # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
872- # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
873- # which fails the topological sort
874- neweq = lhs ~ getindex_wrapper (
875- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
876- obs[i] = neweq
877- subeqi = findfirst (isequal (eq), subeqs)
878- if subeqi != = nothing
879- subeqs[subeqi] = neweq
880- end
881- end
882- # end HACK 1
883-
884842 array || continue
885843 iscall (lhs) || continue
886844 operation (lhs) === getindex || continue
@@ -891,33 +849,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
891849 continue
892850 end
893851
894- # Also do CSE for `equations(sys)`
895- if cse
896- for (i, eq) in enumerate (neweqs)
897- (; lhs, rhs) = eq
898- is_getindexed_array (rhs) || continue
899- rhs_arr = arguments (rhs)[1 ]
900- if ! haskey (rhs_to_tempvar, rhs_arr)
901- tempvar = gensym (Symbol (lhs))
902- N = length (rhs_arr)
903- tempvar = unwrap (Symbolics. variable (
904- tempvar; T = Symbolics. symtype (rhs_arr)))
905- tempvar = setmetadata (
906- tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
907- vars! (all_vars, rhs_arr)
908- tempeq = tempvar ~ rhs_arr
909- rhs_to_tempvar[rhs_arr] = tempvar
910- push! (obs, tempeq)
911- push! (subeqs, tempeq)
912- end
913- # don't need getindex_wrapper, but do it anyway to know that this
914- # hack took place
915- neweq = lhs ~ getindex_wrapper (
916- rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
917- neweqs[i] = neweq
918- end
919- end
920-
921852 # count variables in unknowns if they are scalarized forms of variables
922853 # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
923854 # is an observed equation.
@@ -1007,10 +938,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
1007938instead, which calls this function internally.
1008939"""
1009940function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
1010- simplify = false , cse_hack = true , array_hack = true , kwargs... )
941+ simplify = false , array_hack = true , kwargs... )
1011942 var_eq_matching, full_var_eq_matching = tearing (state)
1012943 invalidate_cache! (tearing_reassemble (
1013- state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
944+ state, var_eq_matching, full_var_eq_matching; mm, simplify, array_hack))
1014945end
1015946
1016947"""
@@ -1032,7 +963,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
1032963the system is balanced.
1033964"""
1034965function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
1035- mm = nothing , cse_hack = true , array_hack = true , kwargs... )
966+ mm = nothing , array_hack = true , kwargs... )
1036967 jac = let state = state
1037968 (eqs, vars) -> begin
1038969 symeqs = EquationsView (state)[eqs]
@@ -1056,5 +987,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1056987 end
1057988 var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
1058989 kwargs... )
1059- tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack)
990+ tearing_reassemble (state, var_eq_matching; simplify, mm, array_hack)
1060991end
0 commit comments