Skip to content

Commit 1025363

Browse files
feat: extend CSE hack to non-observed equations
1 parent d2fe0eb commit 1025363

File tree

3 files changed

+79
-12
lines changed

3 files changed

+79
-12
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584
end
585585
@set! sys.unknowns = unknowns
586586

587-
obs, subeqs = cse_and_array_hacks(
587+
obs, subeqs, deps = cse_and_array_hacks(
588588
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
589589

590590
@set! sys.eqs = neweqs
@@ -637,10 +637,7 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
637637
vars!(all_vars, rhs)
638638

639639
# HACK 1
640-
if cse &&
641-
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
642-
iscall(rhs) && operation(rhs) === getindex &&
643-
Symbolics.shape(rhs) != Symbolics.Unknown()
640+
if cse && is_getindexed_array(rhs)
644641
rhs_arr = arguments(rhs)[1]
645642
if !haskey(rhs_to_tempvar, rhs_arr)
646643
tempvar = gensym(Symbol(lhs))
@@ -677,6 +674,33 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
677674
arr_obs_occurrences[arg1] = cnt + 1
678675
continue
679676
end
677+
678+
# Also do CSE for `equations(sys)`
679+
if cse
680+
for (i, eq) in enumerate(neweqs)
681+
(; lhs, rhs) = eq
682+
is_getindexed_array(rhs) || continue
683+
rhs_arr = arguments(rhs)[1]
684+
if !haskey(rhs_to_tempvar, rhs_arr)
685+
tempvar = gensym(Symbol(lhs))
686+
N = length(rhs_arr)
687+
tempvar = unwrap(Symbolics.variable(
688+
tempvar; T = Symbolics.symtype(rhs_arr)))
689+
tempvar = setmetadata(
690+
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
691+
tempeq = tempvar ~ rhs_arr
692+
rhs_to_tempvar[rhs_arr] = tempvar
693+
push!(obs, tempeq)
694+
push!(subeqs, tempeq)
695+
end
696+
# don't need getindex_wrapper, but do it anyway to know that this
697+
# hack took place
698+
neweq = lhs ~ getindex_wrapper(
699+
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
700+
neweqs[i] = neweq
701+
end
702+
end
703+
680704
# count variables in unknowns if they are scalarized forms of variables
681705
# also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
682706
# is an observed equation.
@@ -713,7 +737,16 @@ function cse_and_array_hacks(obs, subeqs, unknowns, neweqs; cse = true, array =
713737
# need to re-sort subeqs
714738
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
715739

716-
return obs, subeqs
740+
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
741+
for i in 1:length(subeqs)]
742+
743+
return obs, subeqs, deps
744+
end
745+
746+
function is_getindexed_array(rhs)
747+
(!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
748+
iscall(rhs) && operation(rhs) === getindex &&
749+
Symbolics.shape(rhs) != Symbolics.Unknown()
717750
end
718751

719752
# PART OF HACK 1

src/systems/nonlinear/initializesystem.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ function generate_initializesystem(sys::ODESystem;
1212
algebraic_only = false,
1313
check_units = true, check_defguess = false,
1414
name = nameof(sys), kwargs...)
15-
trueobs = unhack_observed(observed(sys))
15+
trueobs, eqs = unhack_observed(observed(sys), equations(sys))
1616
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
1717
vars_set = Set(vars) # for efficient in-lookup
1818

19-
eqs = equations(sys)
2019
idxs_diff = isdiffeq.(eqs)
2120
idxs_alge = .!idxs_diff
2221

@@ -329,11 +328,11 @@ end
329328
Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
330329
initialization.
331330
"""
332-
function unhack_observed(eqs::Vector{Equation})
331+
function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation})
333332
subs = Dict()
334333
tempvars = Set()
335334
rm_idxs = Int[]
336-
for (i, eq) in enumerate(eqs)
335+
for (i, eq) in enumerate(obseqs)
337336
iscall(eq.rhs) || continue
338337
if operation(eq.rhs) == StructuralTransformations.change_origin
339338
push!(rm_idxs, i)
@@ -347,14 +346,27 @@ function unhack_observed(eqs::Vector{Equation})
347346
end
348347

349348
for (i, eq) in enumerate(eqs)
349+
iscall(eq.rhs) || continue
350+
if operation(eq.rhs) == StructuralTransformations.getindex_wrapper
351+
var, idxs = arguments(eq.rhs)
352+
subs[eq.rhs] = var[idxs...]
353+
push!(tempvars, var)
354+
end
355+
end
356+
357+
for (i, eq) in enumerate(obseqs)
350358
if eq.lhs in tempvars
351359
subs[eq.lhs] = eq.rhs
352360
push!(rm_idxs, i)
353361
end
354362
end
355363

356-
eqs = eqs[setdiff(eachindex(eqs), rm_idxs)]
357-
return map(eqs) do eq
364+
obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)]
365+
obseqs = map(obseqs) do eq
366+
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
367+
end
368+
eqs = map(eqs) do eq
358369
fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs)
359370
end
371+
return obseqs, eqs
360372
end

test/structural_transformation/utils.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ end
8585
iscall(eq.rhs) && operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
8686
StructuralTransformations.change_origin]
8787
end
88+
89+
@testset "CSE hack in equations(sys)" begin
90+
val[] = 0
91+
@variables z(t)[1:2]
92+
@mtkbuild sys = ODESystem(
93+
[D(y) ~ foo(x), D(x) ~ sum(y), zeros(2) ~ foo(prod(z))], t)
94+
@test length(equations(sys)) == 5
95+
@test length(observed(sys)) == 2
96+
prob = ODEProblem(
97+
sys, [y => ones(2), z => 2ones(2), x => 3.0], (0.0, 1.0), [foo => _tmp_fn2])
98+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
99+
@test val[] == 2
100+
101+
isys = ModelingToolkit.generate_initializesystem(sys)
102+
@test length(unknowns(isys)) == 5
103+
@test length(equations(isys)) == 2
104+
@test !any(equations(isys)) do eq
105+
iscall(eq.rhs) &&
106+
operation(eq.rhs) in [StructuralTransformations.getindex_wrapper,
107+
StructuralTransformations.change_origin]
108+
end
109+
end
88110
end
89111

90112
@testset "array and cse hacks can be disabled" begin

0 commit comments

Comments
 (0)