Skip to content

Commit 6c3576f

Browse files
feat: add simple CSE for array scalarization case
1 parent 5c2edf4 commit 6c3576f

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584
end
585585
@set! sys.unknowns = unknowns
586586

587+
# HACK: Since we don't support array equations, any equation of the sort
588+
# `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589+
# calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590+
# for this case to avoid the unnecessary cost.
591+
# This and the below hack are implemented simultaneously
592+
593+
# mapping of rhs to temporary CSE variable
594+
# `f(...) => tmpvar` in above example
595+
rhs_to_tempvar = Dict()
596+
587597
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588598
# are equations, add an equation `p ~ [p[1], p[2], ...]`
589599
# allow topsort to reorder them
@@ -597,9 +607,42 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
597607
arr_obs_occurrences = Dict()
598608
# to check if array variables occur in unscalarized form anywhere
599609
all_vars = Set()
600-
for eq in obs
601-
vars!(all_vars, eq.rhs)
610+
for (i, eq) in enumerate(obs)
602611
lhs = eq.lhs
612+
rhs = eq.rhs
613+
vars!(all_vars, rhs)
614+
615+
# HACK 1
616+
if (!ModelingToolkit.isvariable(rhs) || ModelingToolkit.iscalledparameter(rhs)) &&
617+
iscall(rhs) && operation(rhs) === getindex &&
618+
Symbolics.shape(rhs) != Symbolics.Unknown()
619+
rhs_arr = arguments(rhs)[1]
620+
if !haskey(rhs_to_tempvar, rhs_arr)
621+
tempvar = gensym(Symbol(lhs))
622+
N = length(rhs_arr)
623+
tempvar = unwrap(Symbolics.variable(
624+
tempvar; T = Symbolics.symtype(rhs_arr)))
625+
tempvar = setmetadata(
626+
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
627+
tempeq = tempvar ~ rhs_arr
628+
rhs_to_tempvar[rhs_arr] = tempvar
629+
push!(obs, tempeq)
630+
push!(subeqs, tempeq)
631+
end
632+
633+
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
634+
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
635+
# which fails the topological sort
636+
neweq = lhs ~ getindex_wrapper(
637+
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
638+
obs[i] = neweq
639+
subeqi = findfirst(isequal(eq), subeqs)
640+
if subeqi !== nothing
641+
subeqs[subeqi] = neweq
642+
end
643+
end
644+
# end HACK 1
645+
603646
iscall(lhs) || continue
604647
operation(lhs) === getindex || continue
605648
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
@@ -640,6 +683,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
640683
end
641684
append!(obs, obs_arr_eqs)
642685
append!(subeqs, obs_arr_eqs)
686+
643687
# need to re-sort subeqs
644688
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
645689

@@ -659,6 +703,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
659703
return invalidate_cache!(sys)
660704
end
661705

706+
# PART OF HACK 1
707+
getindex_wrapper(x, i) = x[i...]
708+
709+
@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})
710+
711+
# PART OF HACK 2
662712
function change_origin(origin, arr)
663713
return origin(arr)
664714
end

test/structural_transformation/utils.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,25 @@ end
4848
@mtkbuild sys = ODESystem(
4949
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
5050
@test length(equations(sys)) == 1
51-
@test length(observed(sys)) == 6
51+
@test length(observed(sys)) == 7
5252
@test any(eq -> isequal(eq.lhs, y), observed(sys))
5353
@test any(eq -> isequal(eq.lhs, z), observed(sys))
5454
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
5555
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
5656
end
57+
58+
@testset "scalarized array observed calling same function multiple times" begin
59+
@variables x(t) y(t)[1:2]
60+
@parameters foo(::Real)[1:2]
61+
val = Ref(0)
62+
function _tmp_fn2(x)
63+
val[] += 1
64+
return [x, 2x]
65+
end
66+
@mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
67+
@test length(equations(sys)) == 1
68+
@test length(observed(sys)) == 3
69+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
70+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
71+
@test val[] == 1
72+
end

0 commit comments

Comments
 (0)