230
230
=#
231
231
232
232
function tearing_reassemble (state:: TearingState , var_eq_matching,
233
- full_var_eq_matching = nothing ; simplify = false , mm = nothing )
233
+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
234
234
@unpack fullvars, sys, structure = state
235
235
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
236
236
extra_vars = Int[]
@@ -584,24 +584,48 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584
584
end
585
585
@set! sys. unknowns = unknowns
586
586
587
- # HACK: Since we don't support array equations, any equation of the sort
588
- # `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589
- # calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590
- # for this case to avoid the unnecessary cost.
591
- # This and the below hack are implemented simultaneously
587
+ obs, subeqs = cse_and_array_hacks (
588
+ obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
592
589
590
+ @set! sys. eqs = neweqs
591
+ @set! sys. observed = obs
592
+
593
+ @set! sys. substitutions = Substitutions (subeqs, deps)
594
+
595
+ # Only makes sense for time-dependent
596
+ # TODO : generalize to SDE
597
+ if sys isa ODESystem
598
+ @set! sys. schedule = Schedule (var_eq_matching, dummy_sub)
599
+ end
600
+ sys = schedule (sys)
601
+ @set! state. sys = sys
602
+ @set! sys. tearing_state = state
603
+ return invalidate_cache! (sys)
604
+ end
605
+
606
+ """
607
+ # HACK 1
608
+
609
+ Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
610
+ gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
611
+ _very_ expensive. this hack performs a limited form of CSE specifically for this case to
612
+ avoid the unnecessary cost. This and the below hack are implemented simultaneously
613
+
614
+ # HACK 2
615
+
616
+ Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
617
+ equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
618
+ if all `p[i]` are present and the unscalarized form is used in any equation (observed or
619
+ not) we first count the number of times the scalarized form of each observed variable
620
+ occurs in observed equations (and unknowns if it's split).
621
+ """
622
+ function cse_and_array_hacks (obs, subeqs, unknowns, neweqs; cse = true , array = true )
623
+ # HACK 1
593
624
# mapping of rhs to temporary CSE variable
594
625
# `f(...) => tmpvar` in above example
595
626
rhs_to_tempvar = Dict ()
596
627
597
- # HACK: Add equations for array observed variables. If `p[i] ~ (...)`
598
- # are equations, add an equation `p ~ [p[1], p[2], ...]`
599
- # allow topsort to reorder them
600
- # only add the new equation if all `p[i]` are present and the unscalarized
601
- # form is used in any equation (observed or not)
602
- # we first count the number of times the scalarized form of each observed
603
- # variable occurs in observed equations (and unknowns if it's split).
604
-
628
+ # HACK 2
605
629
# map of array observed variable (unscalarized) to number of its
606
630
# scalarized terms that appear in observed equations
607
631
arr_obs_occurrences = Dict ()
@@ -613,7 +637,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
613
637
vars! (all_vars, rhs)
614
638
615
639
# HACK 1
616
- if (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
640
+ if cse &&
641
+ (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
617
642
iscall (rhs) && operation (rhs) === getindex &&
618
643
Symbolics. shape (rhs) != Symbolics. Unknown ()
619
644
rhs_arr = arguments (rhs)[1 ]
@@ -643,6 +668,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
643
668
end
644
669
# end HACK 1
645
670
671
+ array || continue
646
672
iscall (lhs) || continue
647
673
operation (lhs) === getindex || continue
648
674
Symbolics. shape (lhs) != Symbolics. Unknown () || continue
@@ -687,20 +713,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
687
713
# need to re-sort subeqs
688
714
subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
689
715
690
- @set! sys. eqs = neweqs
691
- @set! sys. observed = obs
692
-
693
- @set! sys. substitutions = Substitutions (subeqs, deps)
694
-
695
- # Only makes sense for time-dependent
696
- # TODO : generalize to SDE
697
- if sys isa ODESystem
698
- @set! sys. schedule = Schedule (var_eq_matching, dummy_sub)
699
- end
700
- sys = schedule (sys)
701
- @set! state. sys = sys
702
- @set! sys. tearing_state = state
703
- return invalidate_cache! (sys)
716
+ return obs, subeqs
704
717
end
705
718
706
719
# PART OF HACK 1
@@ -733,10 +746,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
733
746
instead, which calls this function internally.
734
747
"""
735
748
function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
736
- simplify = false , kwargs... )
749
+ simplify = false , cse_hack = true , array_hack = true , kwargs... )
737
750
var_eq_matching, full_var_eq_matching = tearing (state)
738
751
invalidate_cache! (tearing_reassemble (
739
- state, var_eq_matching, full_var_eq_matching; mm, simplify))
752
+ state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack ))
740
753
end
741
754
742
755
"""
@@ -758,7 +771,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
758
771
the system is balanced.
759
772
"""
760
773
function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
761
- mm = nothing , kwargs... )
774
+ mm = nothing , cse_hack = true , array_hack = true , kwargs... )
762
775
jac = let state = state
763
776
(eqs, vars) -> begin
764
777
symeqs = EquationsView (state)[eqs]
@@ -782,5 +795,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
782
795
end
783
796
var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
784
797
kwargs... )
785
- tearing_reassemble (state, var_eq_matching; simplify, mm)
798
+ tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack )
786
799
end
0 commit comments