Skip to content

Commit 3658f4d

Browse files
feat: add proper scoping to CSE
1 parent 3603725 commit 3658f4d

File tree

2 files changed

+49
-21
lines changed

2 files changed

+49
-21
lines changed

src/code.jl

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,21 @@ end
746746

747747
CSEState() = CSEState(Union{Assignment, DestructuredArgs}[], IdDict())
748748

749-
Base.copy(x::CSEState) = CSEState(copy(x.sorted_eprs), copy(x.visited))
749+
Base.copy(x::CSEState) = CSEState(copy(x.sorted_exprs), copy(x.visited))
750+
751+
"""
752+
$(TYPEDSIGNATURES)
753+
754+
Return a `CSEState` for a new scope inside the one represented by `state`. The new
755+
`CSEState` will use previously-CSEd bindings for expressions only involving variables
756+
outside the new scope, but will generate new bindings for variables defined in this scope.
757+
The new bindings will not affect the outer scope.
758+
"""
759+
function new_scope(state::CSEState)
760+
state = copy(state)
761+
empty!(state.sorted_exprs)
762+
return state
763+
end
750764

751765
"""
752766
$(SIGNATURES)
@@ -772,17 +786,29 @@ end
772786
"""
773787
$(TYPEDSIGNATURES)
774788
775-
Perform Common Subexpression Elimination on the given expression `expr`. Return a `Let`
776-
that computes `expr` with CSE.
789+
Perform Common Subexpression Elimination on the given expression `expr`. Return an
790+
equivalent `expr` with optimized computation.
777791
"""
778792
function cse(expr)
779793
state = CSEState()
780794
newexpr = cse!(expr, state)
781-
if newexpr isa Func # special-case `Func` because wrapping it in a `Let` makes no sense
782-
return Func(newexpr.args, newexpr.kwargs, Let(state.sorted_exprs, newexpr.body, false), newexpr.pre)
783-
else
784-
return Let(state.sorted_exprs, newexpr, false)
785-
end
795+
return apply_cse(newexpr, state)
796+
end
797+
798+
"""
799+
$(TYPEDSIGNATURES)
800+
801+
Given a CSEd expression `newexpr` and the corresponding `state`, return an equivalent
802+
expression with optimized computation.
803+
804+
This is also used when introducing new scopes in subexpressions.
805+
"""
806+
function apply_cse(newexpr, state::CSEState)
807+
# we special-case an empty `sorted_exprs` because e.g. if `expr` is a `Func`, it will
808+
# introduce a new scope and not add bindings to `state`. We don't want to wrap it
809+
# in a `Let`.
810+
isempty(state.sorted_exprs) && return newexpr
811+
return Let(state.sorted_exprs, newexpr, false)
786812
end
787813

788814
"""
@@ -877,24 +903,26 @@ function cse!(x::DestructuredArgs, state::CSEState)
877903
end
878904

879905
function cse!(x::Let, state::CSEState)
880-
# if we just CSE the body, the CSE subexpressions will be wrapped around this Let,
881-
# but will depend on variables defined in the assignments of this Let which is
882-
# incorrect. Instead, we include `x.pairs` in `state.sorted_exprs`, and return
883-
# a `Let` with no assignments and a CSEd body.
906+
state = new_scope(state)
907+
# `Let` introduces a new scope. For each assignment `p` in `x.pairs`, we CSE it
908+
# and then append it to the new assignments from CSE. This is because the assignments
909+
# are imperative, so the CSE assignments for a given `p` can include previous `p`,
910+
# preventing us from simply wrapping the `Let` in another `Let`.
884911
for p in x.pairs
885-
# cse the assignments individually too
886912
newp = cse!(p, state)
887913
push!(state.sorted_exprs, newp)
888914
end
889-
return Let([], cse!(x.body, state), x.let_block)
915+
newbody = cse!(x.body, state)
916+
return Let(state.sorted_exprs, newbody, x.let_block)
890917
end
891918

892919
function cse!(x::Func, state::CSEState)
893-
return Func(x.args, x.kwargs, cse!(x.body, state), x.pre)
920+
state = new_scope(state)
921+
return Func(x.args, x.kwargs, apply_cse(cse!(x.body, state), state), x.pre)
894922
end
895923

896924
function cse!(x::AtIndex, state::CSEState)
897-
return AtIndex(x.i, cse!(x.elem, state))
925+
return AtIndex(cse!(x.i, state), cse!(x.elem, state))
898926
end
899927

900928
function cse!(x::MakeTuple, state::CSEState)

test/cse.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using SymbolicUtils, SymbolicUtils.Code, Test
1+
using SymbolicUtils, SymbolicUtils.Code, SparseArrays, Test
2+
using SymbolicUtils.Code: topological_sort
23
using RuntimeGeneratedFunctions
34

45
RuntimeGeneratedFunctions.init(@__MODULE__)
@@ -45,8 +46,7 @@ end
4546
sorted_nodes = topological_sort(expr)
4647
@test isempty(sorted_nodes)
4748
let_expr = cse(expr)
48-
@test isempty(let_expr.pairs)
49-
@test isequal(let_expr.body, a)
49+
@test isequal(let_expr, expr)
5050

5151
# array symbolics
5252
# https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739
@@ -71,9 +71,9 @@ end
7171

7272
@testset "Expr" begin
7373
ex = :(a^2 + sin(a^2))
74-
@test isequal(cse(ex).body, ex)
74+
@test isequal(cse(ex), ex)
7575
ex = LiteralExpr(ex)
76-
@test isequal(cse(ex).body, ex)
76+
@test isequal(cse(ex), ex)
7777
end
7878

7979
@testset "Tuple" begin

0 commit comments

Comments
 (0)