Skip to content

Commit d2fe0eb

Browse files
feat: allow CSE and array hacks to be disabled
1 parent 79a7fc9 commit d2fe0eb

File tree

2 files changed

+90
-33
lines changed

2 files changed

+90
-33
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ end
230230
=#
231231

232232
function tearing_reassemble(state::TearingState, var_eq_matching,
233-
full_var_eq_matching = nothing; simplify = false, mm = nothing)
233+
full_var_eq_matching = nothing; simplify = false, mm = nothing, cse_hack = true, array_hack = true)
234234
@unpack fullvars, sys, structure = state
235235
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
236236
extra_vars = Int[]
@@ -584,24 +584,48 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584
end
585585
@set! sys.unknowns = unknowns
586586

587-
# HACK: Since we don't support array equations, any equation of the sort
588-
# `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589-
# calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590-
# for this case to avoid the unnecessary cost.
591-
# This and the below hack are implemented simultaneously
587+
obs, subeqs = cse_and_array_hacks(
588+
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
592589

590+
@set! sys.eqs = neweqs
591+
@set! sys.observed = obs
592+
593+
@set! sys.substitutions = Substitutions(subeqs, deps)
594+
595+
# Only makes sense for time-dependent
596+
# TODO: generalize to SDE
597+
if sys isa ODESystem
598+
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
599+
end
600+
sys = schedule(sys)
601+
@set! state.sys = sys
602+
@set! sys.tearing_state = state
603+
return invalidate_cache!(sys)
604+
end
605+
606+
"""
607+
# HACK 1
608+
609+
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
610+
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
611+
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
612+
avoid the unnecessary cost. This and the below hack are implemented simultaneously
613+
614+
# HACK 2
615+
616+
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
617+
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
618+
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
619+
not) we first count the number of times the scalarized form of each observed variable
620+
occurs in observed equations (and unknowns if it's split).
621+
"""
622+
function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array = true)
623+
# HACK 1
593624
# mapping of rhs to temporary CSE variable
594625
# `f(...) => tmpvar` in above example
595626
rhs_to_tempvar = Dict()
596627

597-
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
598-
# are equations, add an equation `p ~ [p[1], p[2], ...]`
599-
# allow topsort to reorder them
600-
# only add the new equation if all `p[i]` are present and the unscalarized
601-
# form is used in any equation (observed or not)
602-
# we first count the number of times the scalarized form of each observed
603-
# variable occurs in observed equations (and unknowns if it's split).
604-
628+
# HACK 2
605629
# map of array observed variable (unscalarized) to number of its
606630
# scalarized terms that appear in observed equations
607631
arr_obs_occurrences = Dict()
@@ -613,7 +637,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
613637
vars!(all_vars, rhs)
614638

615639
# HACK 1
616-
if (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
640+
if cse &&
641+
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
617642
iscall(rhs) && operation(rhs) === getindex &&
618643
Symbolics.shape(rhs) != Symbolics.Unknown()
619644
rhs_arr = arguments(rhs)[1]
@@ -643,6 +668,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
643668
end
644669
# end HACK 1
645670

671+
array || continue
646672
iscall(lhs) || continue
647673
operation(lhs) === getindex || continue
648674
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
@@ -687,20 +713,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
687713
# need to re-sort subeqs
688714
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
689715

690-
@set! sys.eqs = neweqs
691-
@set! sys.observed = obs
692-
693-
@set! sys.substitutions = Substitutions(subeqs, deps)
694-
695-
# Only makes sense for time-dependent
696-
# TODO: generalize to SDE
697-
if sys isa ODESystem
698-
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
699-
end
700-
sys = schedule(sys)
701-
@set! state.sys = sys
702-
@set! sys.tearing_state = state
703-
return invalidate_cache!(sys)
716+
return obs, subeqs
704717
end
705718

706719
# PART OF HACK 1
@@ -733,10 +746,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
733746
instead, which calls this function internally.
734747
"""
735748
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
736-
simplify = false, kwargs...)
749+
simplify = false, cse_hack = true, array_hack = true, kwargs...)
737750
var_eq_matching, full_var_eq_matching = tearing(state)
738751
invalidate_cache!(tearing_reassemble(
739-
state, var_eq_matching, full_var_eq_matching; mm, simplify))
752+
state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack))
740753
end
741754

742755
"""
@@ -758,7 +771,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
758771
the system is balanced.
759772
"""
760773
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
761-
mm = nothing, kwargs...)
774+
mm = nothing, cse_hack = true, array_hack = true, kwargs...)
762775
jac = let state = state
763776
(eqs, vars) -> begin
764777
symeqs = EquationsView(state)[eqs]
@@ -782,5 +795,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
782795
end
783796
var_eq_matching = dummy_derivative_graph!(state, jac; state_priority,
784797
kwargs...)
785-
tearing_reassemble(state, var_eq_matching; simplify, mm)
798+
tearing_reassemble(state, var_eq_matching; simplify, mm, cse_hack, array_hack)
786799
end

test/structural_transformation/utils.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,47 @@ end
8686
StructuralTransformations.change_origin]
8787
end
8888
end
89+
90+
@testset "array and cse hacks can be disabled" begin
91+
@testset "fully_determined = true" begin
92+
@variables x(t) y(t)[1:2] z(t)[1:2]
93+
@parameters foo(::AbstractVector)[1:2]
94+
_tmp_fn(x) = 2x
95+
@named sys = ODESystem(
96+
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
97+
98+
sys1 = structural_simplify(sys; cse_hack = false)
99+
@test length(observed(sys1)) == 6
100+
@test !any(observed(sys1)) do eq
101+
iscall(eq.rhs) &&
102+
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
103+
end
104+
105+
sys2 = structural_simplify(sys; array_hack = false)
106+
@test length(observed(sys2)) == 5
107+
@test !any(observed(sys2)) do eq
108+
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
109+
end
110+
end
111+
112+
@testset "fully_determined = false" begin
113+
@variables x(t) y(t)[1:2] z(t)[1:2] w(t)
114+
@parameters foo(::AbstractVector)[1:2]
115+
_tmp_fn(x) = 2x
116+
@named sys = ODESystem(
117+
[D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
118+
119+
sys1 = structural_simplify(sys; cse_hack = false, fully_determined = false)
120+
@test length(observed(sys1)) == 6
121+
@test !any(observed(sys1)) do eq
122+
iscall(eq.rhs) &&
123+
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
124+
end
125+
126+
sys2 = structural_simplify(sys; array_hack = false, fully_determined = false)
127+
@test length(observed(sys2)) == 5
128+
@test !any(observed(sys2)) do eq
129+
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
130+
end
131+
end
132+
end

0 commit comments

Comments
 (0)