Skip to content

Commit a45e2d0

Browse files
feat: add simple CSE for array scalarization case
1 parent a5d7a48 commit a45e2d0

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,16 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
584584
end
585585
@set! sys.unknowns = unknowns
586586

587+
# HACK: Since we don't support array equations, any equation of the sort
588+
# `x[1:n] ~ f(...)[1:n]` gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly
589+
# calling `f` gets _very_ expensive. this hack performs a limited form of CSE specifically
590+
# for this case to avoid the unnecessary cost.
591+
# This and the below hack are implemented simultaneously
592+
593+
# mapping of rhs to temporary CSE variable
594+
# `f(...) => tmpvar` in above example
595+
rhs_to_tempvar = Dict()
596+
587597
# HACK: Add equations for array observed variables. If `p[i] ~ (...)`
588598
# are equations, add an equation `p ~ [p[1], p[2], ...]`
589599
# allow topsort to reorder them
@@ -597,9 +607,41 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
597607
arr_obs_occurrences = Dict()
598608
# to check if array variables occur in unscalarized form anywhere
599609
all_vars = Set()
600-
for eq in obs
601-
vars!(all_vars, eq.rhs)
610+
for (i, eq) in enumerate(obs)
602611
lhs = eq.lhs
612+
rhs = eq.rhs
613+
vars!(all_vars, rhs)
614+
615+
# HACK 1
616+
if !ModelingToolkit.isvariable(rhs) && iscall(rhs) && operation(rhs) === getindex &&
617+
Symbolics.shape(rhs) != Symbolics.Unknown()
618+
rhs_arr = arguments(rhs)[1]
619+
if !haskey(rhs_to_tempvar, rhs_arr)
620+
tempvar = gensym(Symbol(lhs))
621+
N = length(rhs_arr)
622+
tempvar = unwrap(Symbolics.variable(
623+
tempvar; T = Symbolics.symtype(rhs_arr)))
624+
tempvar = setmetadata(
625+
tempvar, Symbolics.ArrayShapeCtx, Symbolics.shape(rhs_arr))
626+
tempeq = tempvar ~ rhs_arr
627+
rhs_to_tempvar[rhs_arr] = tempvar
628+
push!(obs, tempeq)
629+
push!(subeqs, tempeq)
630+
end
631+
632+
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
633+
# so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
634+
# which fails the topological sort
635+
neweq = lhs ~ getindex_wrapper(
636+
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
637+
obs[i] = neweq
638+
subeqi = findfirst(isequal(eq), subeqs)
639+
if subeqi !== nothing
640+
subeqs[subeqi] = neweq
641+
end
642+
end
643+
# end HACK 1
644+
603645
iscall(lhs) || continue
604646
operation(lhs) === getindex || continue
605647
Symbolics.shape(lhs) != Symbolics.Unknown() || continue
@@ -640,6 +682,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
640682
end
641683
append!(obs, obs_arr_eqs)
642684
append!(subeqs, obs_arr_eqs)
685+
643686
# need to re-sort subeqs
644687
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
645688

@@ -659,6 +702,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
659702
return invalidate_cache!(sys)
660703
end
661704

705+
# PART OF HACK 1
706+
getindex_wrapper(x, i) = x[i...]
707+
708+
@register_symbolic getindex_wrapper(x::AbstractArray, i::Tuple{Vararg{Int}})
709+
710+
# PART OF HACK 2
662711
function change_origin(origin, arr)
663712
return origin(arr)
664713
end

test/structural_transformation/utils.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,25 @@ end
4848
@mtkbuild sys = ODESystem(
4949
[D(x) ~ z[1] + z[2] + foo(z)[1], y[1] ~ 2t, y[2] ~ 3t, z ~ foo(y)], t)
5050
@test length(equations(sys)) == 1
51-
@test length(observed(sys)) == 6
51+
@test length(observed(sys)) == 7
5252
@test any(eq -> isequal(eq.lhs, y), observed(sys))
5353
@test any(eq -> isequal(eq.lhs, z), observed(sys))
5454
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn])
5555
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
5656
end
57+
58+
@testset "scalarized array observed calling same function multiple times" begin
59+
@variables x(t) y(t)[1:2]
60+
@parameters foo(::Real)[1:2]
61+
val = Ref(0)
62+
function _tmp_fn2(x)
63+
val[] += 1
64+
return [x, 2x]
65+
end
66+
@mtkbuild sys = ODESystem([D(x) ~ y[1] + y[2], y ~ foo(x)], t)
67+
@test length(equations(sys)) == 1
68+
@test length(observed(sys)) == 3
69+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [foo => _tmp_fn2])
70+
@test_nowarn prob.f(prob.u0, prob.p, 0.0)
71+
@test val[] == 1
72+
end

0 commit comments

Comments
 (0)