Skip to content

Commit e9b91d2

Browse files
refactor: remove CSE hack
1 parent 66cc813 commit e9b91d2

File tree

4 files changed

+47
-168
lines changed

4 files changed

+47
-168
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 11 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ Update the system equations, unknowns, and observables after simplification.
896896
"""
897897
function update_simplified_system!(
898898
state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns;
899-
cse_hack = true, array_hack = true)
899+
array_hack = true)
900900
@unpack solvable_graph, var_to_diff, eq_to_diff, graph = state.structure
901901
diff_to_var = invview(var_to_diff)
902902

@@ -920,8 +920,7 @@ function update_simplified_system!(
920920
unknowns = [unknowns; extra_unknowns]
921921
@set! sys.unknowns = unknowns
922922

923-
obs = cse_and_array_hacks(
924-
sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
923+
obs = tearing_hacks(sys, obs, unknowns, neweqs; array = array_hack)
925924

926925
@set! sys.eqs = neweqs
927926
@set! sys.observed = obs
@@ -977,7 +976,7 @@ differential variables.
977976
according to `full_var_eq_matching`.
978977
"""
979978
function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
980-
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm, cse_hack = true,
979+
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify = false, mm,
981980
array_hack = true, fully_determined = true)
982981
extra_eqs_vars = get_extra_eqs_vars(state, full_var_eq_matching, fully_determined)
983982
neweqs = collect(equations(state))
@@ -1010,7 +1009,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching,
10101009
# var_eq_matching and full_var_eq_matching are now invalidated
10111010

10121011
sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs,
1013-
extra_unknowns; cse_hack, array_hack)
1012+
extra_unknowns; array_hack)
10141013

10151014
@set! state.sys = sys
10161015
@set! sys.tearing_state = state
@@ -1047,60 +1046,22 @@ function get_extra_eqs_vars(
10471046
end
10481047

10491048
"""
1050-
# HACK 1
1051-
1052-
Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
1053-
gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
1054-
_very_ expensive. this hack performs a limited form of CSE specifically for this case to
1055-
avoid the unnecessary cost. This and the below hack are implemented simultaneously
1056-
1057-
# HACK 2
1049+
# HACK
10581050
10591051
Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
10601052
equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
10611053
if all `p[i]` are present and the unscalarized form is used in any equation (observed or
10621054
not) we first count the number of times the scalarized form of each observed variable
10631055
occurs in observed equations (and unknowns if it's split).
10641056
"""
1065-
function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = true)
1066-
# HACK 1
1067-
# mapping of rhs to temporary CSE variable
1068-
# `f(...) => tmpvar` in above example
1069-
rhs_to_tempvar = Dict()
1070-
1071-
# HACK 2
1057+
function tearing_hacks(sys, obs, unknowns, neweqs; array = true)
10721058
# map of array observed variable (unscalarized) to number of its
10731059
# scalarized terms that appear in observed equations
10741060
arr_obs_occurrences = Dict()
10751061
for (i, eq) in enumerate(obs)
10761062
lhs = eq.lhs
10771063
rhs = eq.rhs
10781064

1079-
# HACK 1
1080-
if cse && is_getindexed_array(rhs)
1081-
rhs_arr = arguments(rhs)[1]
1082-
iscall(rhs_arr) && operation(rhs_arr) isa Symbolics.Operator && continue
1083-
if !haskey(rhs_to_tempvar, rhs_arr)
1084-
tempvar = gensym(Symbol(lhs))
1085-
N = length(rhs_arr)
1086-
tempvar = unwrap(Symbolics.variable(
1087-
tempvar; T = Symbolics.symtype(rhs_arr)))
1088-
tempvar = setmetadata(
1089-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
1090-
tempeq = tempvar ~ rhs_arr
1091-
rhs_to_tempvar[rhs_arr] = tempvar
1092-
push!(obs, tempeq)
1093-
end
1094-
1095-
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
1096-
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
1097-
# which fails the topological sort
1098-
neweq = lhs ~ getindex_wrapper(
1099-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
1100-
obs[i] = neweq
1101-
end
1102-
# end HACK 1
1103-
11041065
array || continue
11051066
iscall(lhs) || continue
11061067
operation(lhs) === getindex || continue
@@ -1111,31 +1072,6 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
11111072
continue
11121073
end
11131074

1114-
# Also do CSE for `equations(sys)`
1115-
if cse
1116-
for (i, eq) in enumerate(neweqs)
1117-
(; lhs, rhs) = eq
1118-
is_getindexed_array(rhs) || continue
1119-
rhs_arr = arguments(rhs)[1]
1120-
if !haskey(rhs_to_tempvar, rhs_arr)
1121-
tempvar = gensym(Symbol(lhs))
1122-
N = length(rhs_arr)
1123-
tempvar = unwrap(Symbolics.variable(
1124-
tempvar; T = Symbolics.symtype(rhs_arr)))
1125-
tempvar = setmetadata(
1126-
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
1127-
tempeq = tempvar ~ rhs_arr
1128-
rhs_to_tempvar[rhs_arr] = tempvar
1129-
push!(obs, tempeq)
1130-
end
1131-
# don't need getindex_wrapper, but do it anyway to know that this
1132-
# hack took place
1133-
neweq = lhs ~ getindex_wrapper(
1134-
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
1135-
neweqs[i] = neweq
1136-
end
1137-
end
1138-
11391075
# count variables in unknowns if they are scalarized forms of variables
11401076
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
11411077
# is an observed equation.
@@ -1170,18 +1106,7 @@ function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = tru
11701106
return obs
11711107
end
11721108

1173-
function is_getindexed_array(rhs)
1174-
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
1175-
iscall(rhs) && operation(rhs) === getindex &&
1176-
Symbolics.shape(rhs) != Symbolics.Unknown()
1177-
end
1178-
1179-
# PART OF HACK 1
1180-
getindex_wrapper(x, i) = x[i...]
1181-
1182-
@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})
1183-
1184-
# PART OF HACK 2
1109+
# PART OF HACK
11851110
function change_origin(origin, arr)
11861111
if all(isone, Tuple(origin))
11871112
return arr
@@ -1209,11 +1134,11 @@ new residual equations after tearing. End users are encouraged to call [`mtkcomp
12091134
instead, which calls this function internally.
12101135
"""
12111136
function tearing(sys::AbstractSystem, state = TearingState(sys); mm = nothing,
1212-
simplify = false, cse_hack = true, array_hack = true, fully_determined = true, kwargs...)
1137+
simplify = false, array_hack = true, fully_determined = true, kwargs...)
12131138
var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate = tearing(state)
12141139
invalidate_cache!(tearing_reassemble(
12151140
state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
1216-
simplify, cse_hack, array_hack, fully_determined))
1141+
simplify, array_hack, fully_determined))
12171142
end
12181143

12191144
"""
@@ -1223,7 +1148,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
12231148
the system is balanced.
12241149
"""
12251150
function dummy_derivative(sys, state = TearingState(sys); simplify = false,
1226-
mm = nothing, cse_hack = true, array_hack = true, fully_determined = true, kwargs...)
1151+
mm = nothing, array_hack = true, fully_determined = true, kwargs...)
12271152
jac = let state = state
12281153
(eqs, vars) -> begin
12291154
symeqs = EquationsView(state)[eqs]
@@ -1249,5 +1174,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
12491174
state, jac; state_priority,
12501175
kwargs...)
12511176
tearing_reassemble(state, var_eq_matching, full_var_eq_matching, var_sccs;
1252-
simplify, mm, cse_hack, array_hack, fully_determined)
1177+
simplify, mm, array_hack, fully_determined)
12531178
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -780,20 +780,6 @@ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
780780
push!(rm_idxs, i)
781781
continue
782782
end
783-
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
784-
var, idxs = arguments(eq.rhs)
785-
subs[eq.rhs] = var[idxs...]
786-
push!(tempvars, var)
787-
end
788-
end
789-
790-
for (i, eq) in enumerate(eqs)
791-
iscall(eq.rhs) || continue
792-
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
793-
var, idxs = arguments(eq.rhs)
794-
subs[eq.rhs] = var[idxs...]
795-
push!(tempvars, var)
796-
end
797783
end
798784

799785
for (i, eq) in enumerate(obseqs)

test/code_generation.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,34 @@ end
7979
@test SciMLBase.successful_retcode(sol)
8080
end
8181
end
82+
83+
@testset "scalarized array observed calling same function multiple times" begin
84+
@variables x(t) y(t)[1:2]
85+
@parameters foo(::Real)[1:2]
86+
val = Ref(0)
87+
function _tmp_fn2(x)
88+
val[] += 1
89+
return [x, 2x]
90+
end
91+
@mtkcompile sys = System([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
92+
@test length(equations(sys)) == 1
93+
@test length(observed(sys)) == 3
94+
prob = ODEProblem(sys, [x => 1.0, foo => _tmp_fn2], (0.0, 1.0))
95+
val[] = 0
96+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
97+
@test val[] == 1
98+
99+
@testset "CSE in equations(sys)" begin
100+
val[] = 0
101+
@variables z(t)[1:2]
102+
@mtkcompile sys = System(
103+
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
104+
@test length(equations(sys)) == 5
105+
@test length(observed(sys)) == 0
106+
prob = ODEProblem(
107+
sys, [y => ones(2), z => 2ones(2), x => 3.0, foo => _tmp_fn2], (0.0, 1.0))
108+
val[] = 0
109+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
110+
@test val[] == 2
111+
end
112+
end

test/structural_transformation/utils.jl

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
@mtkcompile sys = System(
5353
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
5454
@test length(equations(sys)) == 1
55-
@test length(observed(sys)) == 7
55+
@test length(observed(sys)) == 6
5656
@test any(obs -> isequal(obs, y), observables(sys))
5757
@test any(obs -> isequal(obs, z), observables(sys))
5858
prob = ODEProblem(sys, [x => 1.0, foo => _tmp_fn], (0.0, 1.0))
@@ -62,76 +62,20 @@ end
6262
@test length(unknowns(isys)) == 5
6363
@test length(equations(isys)) == 4
6464
@test !any(equations(isys)) do eq
65-
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
66-
StructuralTransformations.change_origin]
65+
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.change_origin]
6766
end
6867
end
6968

70-
@testset "scalarized array observed calling same function multiple times" begin
71-
@variables x(t) y(t)[1:2]
72-
@parameters foo(::Real)[1:2]
73-
val = Ref(0)
74-
function _tmp_fn2(x)
75-
val[] += 1
76-
return [x, 2x]
77-
end
78-
@mtkcompile sys = System([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
79-
@test length(equations(sys)) == 1
80-
@test length(observed(sys)) == 4
81-
prob = ODEProblem(sys, [x => 1.0, foo => _tmp_fn2], (0.0, 1.0))
82-
val[] = 0
83-
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
84-
@test val[] == 1
85-
86-
isys = ModelingToolkit.generate_initializesystem(sys)
87-
@test length(unknowns(isys)) == 3
88-
@test length(equations(isys)) == 2
89-
@test !any(equations(isys)) do eq
90-
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
91-
StructuralTransformations.change_origin]
92-
end
93-
94-
@testset "CSE hack in equations(sys)" begin
95-
val[] = 0
96-
@variables z(t)[1:2]
97-
@mtkcompile sys = System(
98-
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
99-
@test length(equations(sys)) == 5
100-
@test length(observed(sys)) == 2
101-
prob = ODEProblem(
102-
sys, [y => ones(2), z => 2ones(2), x => 3.0, foo => _tmp_fn2], (0.0, 1.0))
103-
val[] = 0
104-
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
105-
@test val[] == 2
106-
107-
isys = ModelingToolkit.generate_initializesystem(sys)
108-
@test length(unknowns(isys)) == 5
109-
@test length(equations(isys)) == 2
110-
@test !any(equations(isys)) do eq
111-
iscall(eq.rhs) &&
112-
operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
113-
StructuralTransformations.change_origin]
114-
end
115-
end
116-
end
117-
118-
@testset "array and cse hacks can be disabled" begin
69+
@testset "array hack can be disabled" begin
11970
@testset "fully_determined = true" begin
12071
@variables x(t) y(t)[1:2] z(t)[1:2]
12172
@parameters foo(::AbstractVector)[1:2]
12273
_tmp_fn(x) = 2x
12374
@named sys = System(
12475
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
12576

126-
sys1 = mtkcompile(sys; cse_hack = false)
127-
@test length(observed(sys1)) == 6
128-
@test !any(observed(sys1)) do eq
129-
iscall(eq.rhs) &&
130-
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
131-
end
132-
13377
sys2 = mtkcompile(sys; array_hack = false)
134-
@test length(observed(sys2)) == 5
78+
@test length(observed(sys2)) == 4
13579
@test !any(observed(sys2)) do eq
13680
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
13781
end
@@ -144,15 +88,8 @@ end
14488
@named sys = System(
14589
[D(x) ~ z[1] + z[2] + foo(z)[1] + w, y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
14690

147-
sys1 = mtkcompile(sys; cse_hack = false, fully_determined = false)
148-
@test length(observed(sys1)) == 6
149-
@test !any(observed(sys1)) do eq
150-
iscall(eq.rhs) &&
151-
operation(eq.rhs) == StructuralTransformations.getindex_wrapper
152-
end
153-
15491
sys2 = mtkcompile(sys; array_hack = false, fully_determined = false)
155-
@test length(observed(sys2)) == 5
92+
@test length(observed(sys2)) == 4
15693
@test !any(observed(sys2)) do eq
15794
iscall(eq.rhs) && operation(eq.rhs) == StructuralTransformations.change_origin
15895
end

0 commit comments

Comments
 (0)