@@ -5,11 +5,12 @@ using TermInterface
55
66export toexpr, Assignment, (← ), Let, Func, DestructuredArgs, LiteralExpr,
77 SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
8- SpawnFetch, Multithreaded
8+ SpawnFetch, Multithreaded, cse
99
1010import .. SymbolicUtils
11+ import .. SymbolicUtils. Rewriters
1112import SymbolicUtils: @matchable , Sym, Term, istree, operation, arguments,
12- symtype
13+ symtype, similarterm, unsorted_arguments, metadata
1314
1415# #== state management ==##
1516
@@ -610,4 +611,95 @@ function toexpr(exp::LiteralExpr, st)
610611 recurse_expr (exp. ex, st)
611612end
612613
614+
615+ # ## Code-related utilities
616+
617+ # ## Common subexprssion evaluation
618+
619+ @inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
620+
621+ function _cse! (mem, expr)
622+ istree (expr) || return expr
623+ op = _cse! (mem, operation (expr))
624+ args = map (Base. Fix1 (_cse!, mem), arguments (expr))
625+ t = similarterm (expr, op, args)
626+
627+ v, dict = mem
628+ update! = let v= v, t= t
629+ () -> begin
630+ var = newsym (symtype (t))
631+ push! (v, var ← t)
632+ length (v)
633+ end
634+ end
635+ v[get! (update!, dict, t)]. lhs
636+ end
637+
638+ function cse (expr)
639+ state = Dict {Any, Int} ()
640+ cse_state! (state, expr)
641+ cse_block (state, expr)
642+ end
643+
644+
645+ function _cse (exprs:: AbstractArray )
646+ letblock = cse (Term {Any} (tuple, exprs))
647+ letblock. pairs, arguments (letblock. body)
648+ end
649+
650+ function cse (x:: MakeArray )
651+ assigns, expr = _cse (x. elems)
652+ Let (assigns, MakeArray (expr, x. similarto, x. output_eltype))
653+ end
654+
655+ function cse (x:: SetArray )
656+ assigns, expr = _cse (x. elems)
657+ Let (assigns, SetArray (x. inbounds, x. arr, expr))
658+ end
659+
660+ function cse (x:: MakeSparseArray )
661+ sp = x. array
662+ assigns, expr = _cse (sp. nzval)
663+ if sp isa SparseMatrixCSC
664+ Let (assigns, MakeSparseArray (SparseMatrixCSC (sp. m, sp. n,
665+ sp. colptr, sp. rowval, exprs)))
666+ else
667+ Let (assigns, MakeSparseArray (SparseVector (sp. n, sp. nzinds, exprs)))
668+ end
669+ end
670+
671+
672+ function cse_state! (state, t)
673+ ! istree (t) && return t
674+ state[t] = Base. get! (state, t, 0 ) + 1
675+ foreach (x-> cse_state! (state, x), unsorted_arguments (t))
676+ end
677+
678+ function cse_block! (assignments, counter, names, name, state, x)
679+ if get (state, x, 0 ) > 1
680+ if haskey (names, x)
681+ return names[x]
682+ else
683+ sym = Sym {symtype(x)} (Symbol (name, counter[]))
684+ names[x] = sym
685+ push! (assignments, sym ← x)
686+ counter[] += 1
687+ return sym
688+ end
689+ elseif istree (x)
690+ args = map (a-> cse_block! (assignments, counter, names, name, state,a), unsorted_arguments (x))
691+ return similarterm (x, operation (x), args, symtype (x),
692+ metadata= metadata (x))
693+ else
694+ return x
695+ end
696+ end
697+
698+ function cse_block (state, t, name= Symbol (" var-" , hash (t)))
699+ assignments = Assignment[]
700+ counter = Ref {Int} (1 )
701+ names = Dict {Any, Sym} ()
702+ Let (assignments, cse_block! (assignments, counter, names, name, state, t))
703+ end
704+
613705end
0 commit comments