737737
738738# ## Common subexprssion evaluation
739739
740- @inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
740+ """
741+ newsym!(state::CSEState, ::Type{T})
742+
743+ Generates new symbol of type `T` with unique name in `state`.
744+ """
745+ @inline function newsym! (state, :: Type{T} ) where T
746+ name = " ##cse#$(state. varid[]) "
747+ state. varid[] += 1
748+ Sym {T} (Symbol (name))
749+ end
741750
742751"""
743752 $(TYPEDSIGNATURES)
@@ -769,11 +778,16 @@ struct CSEState
769778 A mapping of symbolic expression to the LHS in `sorted_exprs` that computes it.
770779 """
771780 visited:: IdDict{Any, Any}
781+ """
782+ Integer counter, used to generate unique names for intermediate variables.
783+ """
784+ varid:: Ref{Int}
772785end
773786
774- CSEState () = CSEState (Union{Assignment, DestructuredArgs}[], IdDict ())
787+ CSEState () = CSEState (Union{Assignment, DestructuredArgs}[], IdDict (), Ref ( 1 ) )
775788
776- Base. copy (x:: CSEState ) = CSEState (copy (x. sorted_exprs), copy (x. visited))
789+ # the copy still references the same `varid` Ref to work in nested scopes
790+ Base. copy (x:: CSEState ) = CSEState (copy (x. sorted_exprs), copy (x. visited), x. varid)
777791
778792"""
779793 $(TYPEDSIGNATURES)
@@ -861,13 +875,13 @@ function cse!(expr::Symbolic, state::CSEState)
861875 (_is_array_of_symbolics (arg) || _is_tuple_of_symbolics (arg))
862876 if arg isa Tuple
863877 new_arg = cse! (MakeTuple (arg), state)
864- sym = newsym ( Tuple{symtype .(arg)... })
878+ sym = newsym! (state, Tuple{symtype .(arg)... })
865879 elseif issparse (arg)
866880 new_arg = cse! (MakeSparseArray (arg), state)
867- sym = newsym ( AbstractSparseArray{symtype (eltype (arg)), indextype (arg), ndims (arg)})
881+ sym = newsym! (state, AbstractSparseArray{symtype (eltype (arg)), indextype (arg), ndims (arg)})
868882 else
869883 new_arg = cse! (MakeArray (arg, typeof (arg)), state)
870- sym = newsym ( AbstractArray{symtype (eltype (arg)), ndims (arg)})
884+ sym = newsym! (state, AbstractArray{symtype (eltype (arg)), ndims (arg)})
871885 end
872886 push! (state. sorted_exprs, sym ← new_arg)
873887 state. visited[arg] = sym
@@ -878,7 +892,7 @@ function cse!(expr::Symbolic, state::CSEState)
878892 # use `term` instead of `maketerm` because we only care about the operation being performed
879893 # and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
880894 new_expr = term (operation (expr), args... ; type = symtype (expr))
881- sym = newsym ( symtype (new_expr))
895+ sym = newsym! (state, symtype (new_expr))
882896 push! (state. sorted_exprs, sym ← new_expr)
883897 return sym
884898 end
0 commit comments