230230=#
231231
232232function tearing_reassemble (state:: TearingState , var_eq_matching,
233- full_var_eq_matching = nothing ; simplify = false , mm = nothing )
233+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
234234 @unpack fullvars, sys, structure = state
235235 @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
236236 extra_vars = Int[]
@@ -584,24 +584,48 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584 end
585585 @set! sys. unknowns = unknowns
586586
587- # HACK: Since we don't support array equations, any equation of the sort
588- # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589- # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590- # for this case to avoid the unnecessary cost.
591- # This and the below hack are implemented simultaneously
587+ obs, subeqs = cse_and_array_hacks (
588+ obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
592589
590+ @set! sys. eqs = neweqs
591+ @set! sys. observed = obs
592+
593+ @set! sys. substitutions = Substitutions (subeqs, deps)
594+
595+ # Only makes sense for time-dependent
596+ # TODO : generalize to SDE
597+ if sys isa ODESystem
598+ @set! sys. schedule = Schedule (var_eq_matching, dummy_sub)
599+ end
600+ sys = schedule (sys)
601+ @set! state. sys = sys
602+ @set! sys. tearing_state = state
603+ return invalidate_cache! (sys)
604+ end
605+
606+ """
607+ # HACK 1
608+
609+ Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
610+ gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
611+ _very_ expensive. this hack performs a limited form of CSE specifically for this case to
612+ avoid the unnecessary cost. This and the below hack are implemented simultaneously
613+
614+ # HACK 2
615+
616+ Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
617+ equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
618+ if all `p[i]` are present and the unscalarized form is used in any equation (observed or
619+ not) we first count the number of times the scalarized form of each observed variable
620+ occurs in observed equations (and unknowns if it's split).
621+ """
622+ function cse_and_array_hacks (obs, subeqs, unknowns, neweqs; cse = true , array = true )
623+ # HACK 1
593624 # mapping of rhs to temporary CSE variable
594625 # `f(...) => tmpvar` in above example
595626 rhs_to_tempvar = Dict ()
596627
597- # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
598- # are equations, add an equation `p ~ [p[1], p[2], ...]`
599- # allow topsort to reorder them
600- # only add the new equation if all `p[i]` are present and the unscalarized
601- # form is used in any equation (observed or not)
602- # we first count the number of times the scalarized form of each observed
603- # variable occurs in observed equations (and unknowns if it's split).
604-
628+ # HACK 2
605629 # map of array observed variable (unscalarized) to number of its
606630 # scalarized terms that appear in observed equations
607631 arr_obs_occurrences = Dict ()
@@ -613,7 +637,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
613637 vars! (all_vars, rhs)
614638
615639 # HACK 1
616- if (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
640+ if cse &&
641+ (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
617642 iscall (rhs) && operation (rhs) === getindex &&
618643 Symbolics. shape (rhs) != Symbolics. Unknown ()
619644 rhs_arr = arguments (rhs)[1 ]
@@ -643,6 +668,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
643668 end
644669 # end HACK 1
645670
671+ array || continue
646672 iscall (lhs) || continue
647673 operation (lhs) === getindex || continue
648674 Symbolics. shape (lhs) != Symbolics. Unknown () || continue
@@ -687,20 +713,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
687713 # need to re-sort subeqs
688714 subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
689715
690- @set! sys. eqs = neweqs
691- @set! sys. observed = obs
692-
693- @set! sys. substitutions = Substitutions (subeqs, deps)
694-
695- # Only makes sense for time-dependent
696- # TODO : generalize to SDE
697- if sys isa ODESystem
698- @set! sys. schedule = Schedule (var_eq_matching, dummy_sub)
699- end
700- sys = schedule (sys)
701- @set! state. sys = sys
702- @set! sys. tearing_state = state
703- return invalidate_cache! (sys)
716+ return obs, subeqs
704717end
705718
706719# PART OF HACK 1
@@ -733,10 +746,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
733746instead, which calls this function internally.
734747"""
735748function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
736- simplify = false , kwargs... )
749+ simplify = false , cse_hack = true , array_hack = true , kwargs... )
737750 var_eq_matching, full_var_eq_matching = tearing (state)
738751 invalidate_cache! (tearing_reassemble (
739- state, var_eq_matching, full_var_eq_matching; mm, simplify))
752+ state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack ))
740753end
741754
742755"""
@@ -758,7 +771,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
758771the system is balanced.
759772"""
760773function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
761- mm = nothing , kwargs... )
774+ mm = nothing , cse_hack = true , array_hack = true , kwargs... )
762775 jac = let state = state
763776 (eqs, vars) -> begin
764777 symeqs = EquationsView (state)[eqs]
@@ -782,5 +795,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
782795 end
783796 var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
784797 kwargs... )
785- tearing_reassemble (state, var_eq_matching; simplify, mm)
798+ tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack )
786799end
0 commit comments