11module Code
22
3- using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
3+ using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions,
4+ DocStringExtensions
45
56export toexpr, Assignment, (← ), Let, Func, DestructuredArgs, LiteralExpr,
67 SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
696697
697698@inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
698699
700+ """
701+ $(SIGNATURES)
702+
703+ Perform a topological sort on a symbolic expression represented as a Directed Acyclic
704+ Graph (DAG).
705+
706+ This function takes a symbolic expression `graph` (potentially containing shared common
707+ sub-expressions) and returns an array of `Assignment` objects. Each `Assignment`
708+ represents a node in the sorted order, assigning a fresh symbol to its corresponding
709+ expression. The order ensures that all dependencies of a node appear before the node itself
710+ in the array.
711+
712+ Hash consing is assumed, meaning that structurally identical expressions are represented by
713+ the same object in memory. This allows for efficient equality checks using `IdDict`.
714+ """
715+ function topological_sort (graph)
716+ sorted_nodes = Assignment[]
717+ visited = IdDict ()
718+
719+ function dfs (node)
720+ if haskey (visited, node)
721+ return visited[node]
722+ end
723+ if iscall (node)
724+ args = map (dfs, arguments (node))
725+ new_node = maketerm (typeof (node), operation (node), args, metadata (node))
726+ sym = newsym (symtype (new_node))
727+ push! (sorted_nodes, sym ← new_node)
728+ visited[node] = sym
729+ return sym
730+ elseif _is_array_of_symbolics (node)
731+ new_node = map (dfs, node)
732+ sym = newsym (typeof (new_node))
733+ push! (sorted_nodes, sym ← new_node)
734+ visited[node] = sym
735+ return sym
736+ else
737+ visited[node] = node
738+ return node
739+ end
740+ end
741+
742+ dfs (graph)
743+ return sorted_nodes
744+ end
745+
699746function _cse! (mem, expr)
700747 iscall (expr) || return expr
701748 op = _cse! (mem, operation (expr))
@@ -714,12 +761,16 @@ function _cse!(mem, expr)
714761end
715762
716763function cse (expr)
717- state = Dict {Any, Int} ()
718- cse_state! (state, expr)
719- cse_block (state, expr)
764+ sorted_nodes = topological_sort (expr)
765+ if isempty (sorted_nodes)
766+ return Let (Assignment[], expr)
767+ else
768+ last_assignment = pop! (sorted_nodes)
769+ body = rhs (last_assignment)
770+ return Let (sorted_nodes, body)
771+ end
720772end
721773
722-
723774function _cse (exprs:: AbstractArray )
724775 letblock = cse (Term {Any} (tuple, vec (exprs)))
725776 letblock. pairs, reshape (arguments (letblock. body), size (exprs))
@@ -746,41 +797,4 @@ function cse(x::MakeSparseArray)
746797 end
747798end
748799
749-
750- function cse_state! (state, t)
751- ! iscall (t) && return t
752- state[t] = Base. get (state, t, 0 ) + 1
753- foreach (x-> cse_state! (state, x), arguments (t))
754- end
755-
756- function cse_block! (assignments, counter, names, name, state, x)
757- if get (state, x, 0 ) > 1
758- if haskey (names, x)
759- return names[x]
760- else
761- sym = Sym {symtype(x)} (Symbol (name, counter[]))
762- names[x] = sym
763- push! (assignments, sym ← x)
764- counter[] += 1
765- return sym
766- end
767- elseif iscall (x)
768- args = map (a-> cse_block! (assignments, counter, names, name, state,a), arguments (x))
769- if isterm (x)
770- return term (operation (x), args... )
771- else
772- return maketerm (typeof (x), operation (x), args, metadata (x))
773- end
774- else
775- return x
776- end
777- end
778-
779- function cse_block (state, t, name= Symbol (" var-" , hash (t)))
780- assignments = Assignment[]
781- counter = Ref {Int} (1 )
782- names = Dict {Any, BasicSymbolic} ()
783- Let (assignments, cse_block! (assignments, counter, names, name, state, t))
784- end
785-
786800end
0 commit comments