Skip to content

Commit ad5664a

Browse files
authored
Merge pull request #200 from JuliaSymbolics/s/cse
CSE
2 parents 701636b + c4e2ec2 commit ad5664a

File tree

3 files changed

+105
-2
lines changed

3 files changed

+105
-2
lines changed

src/code.jl

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ using TermInterface
55

66
export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
77
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
8-
SpawnFetch, Multithreaded
8+
SpawnFetch, Multithreaded, cse
99

1010
import ..SymbolicUtils
11+
import ..SymbolicUtils.Rewriters
1112
import SymbolicUtils: @matchable, Sym, Term, istree, operation, arguments,
12-
symtype
13+
symtype, similarterm, unsorted_arguments, metadata
1314

1415
##== state management ==##
1516

@@ -610,4 +611,95 @@ function toexpr(exp::LiteralExpr, st)
610611
recurse_expr(exp.ex, st)
611612
end
612613

614+
615+
### Code-related utilities
616+
617+
### Common subexprssion evaluation
618+
619+
@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))
620+
621+
function _cse!(mem, expr)
622+
istree(expr) || return expr
623+
op = _cse!(mem, operation(expr))
624+
args = map(Base.Fix1(_cse!, mem), arguments(expr))
625+
t = similarterm(expr, op, args)
626+
627+
v, dict = mem
628+
update! = let v=v, t=t
629+
() -> begin
630+
var = newsym(symtype(t))
631+
push!(v, var t)
632+
length(v)
633+
end
634+
end
635+
v[get!(update!, dict, t)].lhs
636+
end
637+
638+
function cse(expr)
639+
state = Dict{Any, Int}()
640+
cse_state!(state, expr)
641+
cse_block(state, expr)
642+
end
643+
644+
645+
function _cse(exprs::AbstractArray)
646+
letblock = cse(Term{Any}(tuple, exprs))
647+
letblock.pairs, arguments(letblock.body)
648+
end
649+
650+
function cse(x::MakeArray)
651+
assigns, expr = _cse(x.elems)
652+
Let(assigns, MakeArray(expr, x.similarto, x.output_eltype))
653+
end
654+
655+
function cse(x::SetArray)
656+
assigns, expr = _cse(x.elems)
657+
Let(assigns, SetArray(x.inbounds, x.arr, expr))
658+
end
659+
660+
function cse(x::MakeSparseArray)
661+
sp = x.array
662+
assigns, expr = _cse(sp.nzval)
663+
if sp isa SparseMatrixCSC
664+
Let(assigns, MakeSparseArray(SparseMatrixCSC(sp.m, sp.n,
665+
sp.colptr, sp.rowval, exprs)))
666+
else
667+
Let(assigns, MakeSparseArray(SparseVector(sp.n, sp.nzinds, exprs)))
668+
end
669+
end
670+
671+
672+
function cse_state!(state, t)
673+
!istree(t) && return t
674+
state[t] = Base.get!(state, t, 0) + 1
675+
foreach(x->cse_state!(state, x), unsorted_arguments(t))
676+
end
677+
678+
function cse_block!(assignments, counter, names, name, state, x)
679+
if get(state, x, 0) > 1
680+
if haskey(names, x)
681+
return names[x]
682+
else
683+
sym = Sym{symtype(x)}(Symbol(name, counter[]))
684+
names[x] = sym
685+
push!(assignments, sym x)
686+
counter[] += 1
687+
return sym
688+
end
689+
elseif istree(x)
690+
args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x))
691+
return similarterm(x, operation(x), args, symtype(x),
692+
metadata=metadata(x))
693+
else
694+
return x
695+
end
696+
end
697+
698+
function cse_block(state, t, name=Symbol("var-", hash(t)))
699+
assignments = Assignment[]
700+
counter = Ref{Int}(1)
701+
names = Dict{Any, Sym}()
702+
Let(assignments, cse_block!(assignments, counter, names, name, state, t))
703+
end
704+
613705
end

test/cse.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using SymbolicUtils, SymbolicUtils.Code, Test
2+
@testset "CSE" begin
3+
@syms x
4+
t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x))))
5+
6+
@test t isa Let
7+
@test length(t.pairs) == 2
8+
@test occursin(t.pairs[1].lhs, t.body)
9+
@test occursin(t.pairs[2].lhs, t.body)
10+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ else
3030
include("rewrite.jl")
3131
include("rulesets.jl")
3232
include("code.jl")
33+
include("cse.jl")
3334
include("interface.jl")
3435
include("fuzz.jl")
3536
include("adjoints.jl")

0 commit comments

Comments
 (0)