@@ -50,33 +50,63 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
5050 }
5151}
5252
53+ /// The `Normalizeable` trait defines a method to determine whether a node can be normalized.
54+ ///
55+ /// Normalization is the process of converting a node into a canonical form that can be used
56+ /// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE),
57+ /// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal.
58+ pub trait Normalizeable {
59+ fn can_normalize ( & self ) -> bool ; // when `true`, the CSE visitor sorts this node's child identifiers by hash before combining them, so child order does not affect the node's identifier
60+ }
61+
62+ /// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
63+ /// normalized nodes in optimizations like Common Subexpression Elimination (CSE).
64+ ///
65+ /// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization)
66+ /// are considered equal in CSE optimization, even if their original forms differ.
67+ ///
68+ /// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their
69+ /// internal representations.
70+ pub trait NormalizeEq : Eq + Normalizeable {
71+ fn normalize_eq ( & self , other : & Self ) -> bool ; // called by `Identifier`'s `PartialEq` impl once the two identifier hashes already match
72+ }
73+
5374/// Identifier that represents a [`TreeNode`] tree.
5475///
5576/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
5677/// "have no collision (as low as possible)"
57- #[ derive( Debug , Eq , PartialEq ) ]
58- struct Identifier < ' n , N > {
78+ #[ derive( Debug , Eq ) ]
79+ struct Identifier < ' n , N : NormalizeEq > {
5980 // Hash of `node` built up incrementally during the first, visiting traversal.
6081 // Its value is not necessarily equal to default hash of the node. E.g. it is not
6182 // equal to `expr.hash()` if the node is `Expr`.
6283 hash : u64 ,
6384 node : & ' n N ,
6485}
6586
66- impl < N > Clone for Identifier < ' _ , N > {
87+ impl < N : NormalizeEq > Clone for Identifier < ' _ , N > {
6788 fn clone ( & self ) -> Self {
6889 * self
6990 }
7091}
71- impl < N > Copy for Identifier < ' _ , N > { }
92+ impl < N : NormalizeEq > Copy for Identifier < ' _ , N > { }
7293
73- impl < N > Hash for Identifier < ' _ , N > {
94+ impl < N : NormalizeEq > Hash for Identifier < ' _ , N > {
7495 fn hash < H : Hasher > ( & self , state : & mut H ) {
7596 state. write_u64 ( self . hash ) ;
7697 }
7798}
7899
79- impl < ' n , N : HashNode > Identifier < ' n , N > {
100+ impl < N : NormalizeEq > PartialEq for Identifier < ' _ , N > { // hand-written instead of derived: two identifiers are equal when their hashes match and their nodes are `normalize_eq`
101+ fn eq ( & self , other : & Self ) -> bool {
102+ self . hash == other. hash && self . node . normalize_eq ( other. node ) // `&&` short-circuits, so `normalize_eq` only runs when the hashes already match
103+ }
104+ }
105+
106+ impl < ' n , N > Identifier < ' n , N >
107+ where
108+ N : HashNode + NormalizeEq ,
109+ {
80110 fn new ( node : & ' n N , random_state : & RandomState ) -> Self {
81111 let mut hasher = random_state. build_hasher ( ) ;
82112 node. hash_node ( & mut hasher) ;
@@ -213,7 +243,11 @@ pub enum FoundCommonNodes<N> {
213243///
214244/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
215245/// because they should not be recognized as common subtree.
216- struct CSEVisitor < ' a , ' n , N , C : CSEController < Node = N > > {
246+ struct CSEVisitor < ' a , ' n , N , C >
247+ where
248+ N : NormalizeEq ,
249+ C : CSEController < Node = N > ,
250+ {
217251 /// statistics of [`TreeNode`]s
218252 node_stats : & ' a mut NodeStats < ' n , N > ,
219253
@@ -244,7 +278,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
244278}
245279
246280/// Record item that used when traversing a [`TreeNode`] tree.
247- enum VisitRecord < ' n , N > {
281+ enum VisitRecord < ' n , N >
282+ where
283+ N : NormalizeEq ,
284+ {
248285 /// Marks the beginning of [`TreeNode`]. It contains:
249286 /// - The post-order index assigned during the first, visiting traversal.
250287 EnterMark ( usize ) ,
@@ -258,7 +295,11 @@ enum VisitRecord<'n, N> {
258295 NodeItem ( Identifier < ' n , N > , bool ) ,
259296}
260297
261- impl < ' n , N : TreeNode + HashNode , C : CSEController < Node = N > > CSEVisitor < ' _ , ' n , N , C > {
298+ impl < ' n , N , C > CSEVisitor < ' _ , ' n , N , C >
299+ where
300+ N : TreeNode + HashNode + NormalizeEq ,
301+ C : CSEController < Node = N > ,
302+ {
262303 /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
263304 /// it. Returns a tuple that contains:
264305 /// - The pre-order index of the [`TreeNode`] we marked.
@@ -271,17 +312,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
271312 /// information up from children to parents via `visit_stack` during the first,
272313 /// visiting traversal and no need to test the expression's validity beforehand with
273314 /// an extra traversal).
274- fn pop_enter_mark ( & mut self ) -> ( usize , Option < Identifier < ' n , N > > , bool ) {
275- let mut node_id = None ;
315+ fn pop_enter_mark (
316+ & mut self ,
317+ can_normalize : bool ,
318+ ) -> ( usize , Option < Identifier < ' n , N > > , bool ) {
319+ let mut node_ids: Vec < Identifier < ' n , N > > = vec ! [ ] ;
276320 let mut is_valid = true ;
277321
278322 while let Some ( item) = self . visit_stack . pop ( ) {
279323 match item {
280324 VisitRecord :: EnterMark ( down_index) => {
325+ if can_normalize {
326+ node_ids. sort_by_key ( |i| i. hash ) ;
327+ }
328+ let node_id = node_ids
329+ . into_iter ( )
330+ . fold ( None , |accum, item| Some ( item. combine ( accum) ) ) ;
281331 return ( down_index, node_id, is_valid) ;
282332 }
283333 VisitRecord :: NodeItem ( sub_node_id, sub_node_is_valid) => {
284- node_id = Some ( sub_node_id . combine ( node_id ) ) ;
334+ node_ids . push ( sub_node_id ) ;
285335 is_valid &= sub_node_is_valid;
286336 }
287337 }
@@ -290,8 +340,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
290340 }
291341}
292342
293- impl < ' n , N : TreeNode + HashNode + Eq , C : CSEController < Node = N > > TreeNodeVisitor < ' n >
294- for CSEVisitor < ' _ , ' n , N , C >
343+ impl < ' n , N , C > TreeNodeVisitor < ' n > for CSEVisitor < ' _ , ' n , N , C >
344+ where
345+ N : TreeNode + HashNode + NormalizeEq ,
346+ C : CSEController < Node = N > ,
295347{
296348 type Node = N ;
297349
@@ -331,7 +383,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
331383 }
332384
333385 fn f_up ( & mut self , node : & ' n Self :: Node ) -> Result < TreeNodeRecursion > {
334- let ( down_index, sub_node_id, sub_node_is_valid) = self . pop_enter_mark ( ) ;
386+ let ( down_index, sub_node_id, sub_node_is_valid) =
387+ self . pop_enter_mark ( node. can_normalize ( ) ) ;
335388
336389 let node_id = Identifier :: new ( node, self . random_state ) . combine ( sub_node_id) ;
337390 let is_valid = C :: is_valid ( node) && sub_node_is_valid;
@@ -369,7 +422,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
369422/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
370423/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
371424/// replaced [`TreeNode`] tree.
372- struct CSERewriter < ' a , ' n , N , C : CSEController < Node = N > > {
425+ struct CSERewriter < ' a , ' n , N , C >
426+ where
427+ N : NormalizeEq ,
428+ C : CSEController < Node = N > ,
429+ {
373430 /// statistics of [`TreeNode`]s
374431 node_stats : & ' a NodeStats < ' n , N > ,
375432
@@ -386,8 +443,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
386443 controller : & ' a mut C ,
387444}
388445
389- impl < N : TreeNode + Eq , C : CSEController < Node = N > > TreeNodeRewriter
390- for CSERewriter < ' _ , ' _ , N , C >
446+ impl < N , C > TreeNodeRewriter for CSERewriter < ' _ , ' _ , N , C >
447+ where
448+ N : TreeNode + NormalizeEq ,
449+ C : CSEController < Node = N > ,
391450{
392451 type Node = N ;
393452
@@ -408,13 +467,30 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
408467 self . down_index += 1 ;
409468 }
410469
411- let ( node, alias) =
412- self . common_nodes . entry ( node_id) . or_insert_with ( || {
413- let node_alias = self . controller . generate_alias ( ) ;
414- ( node, node_alias)
415- } ) ;
416-
417- let rewritten = self . controller . rewrite ( node, alias) ;
470+ // We *must* replace all original nodes with the same `node_id`, not just the first
471+ // node that was inserted into `common_nodes`. This is because nodes with the same
472+ // `node_id` are semantically equivalent, but not necessarily identical.
473+ //
474+ // For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
475+ // In this case, we should replace the common expression `1 + a` with a new variable
476+ // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
477+ // `__common_cse_1`.
478+ //
479+ // The final result would be:
480+ // - `__common_cse_1 as a + 1`
481+ // - `__common_cse_1 as 1 + a`
482+ //
483+ // This way, we can efficiently handle semantically equivalent expressions without
484+ // incorrectly treating them as identical.
485+ let rewritten = if let Some ( ( _, alias) ) = self . common_nodes . get ( & node_id)
486+ {
487+ self . controller . rewrite ( & node, alias)
488+ } else {
489+ let node_alias = self . controller . generate_alias ( ) ;
490+ let rewritten = self . controller . rewrite ( & node, & node_alias) ;
491+ self . common_nodes . insert ( node_id, ( node, node_alias) ) ;
492+ rewritten
493+ } ;
418494
419495 return Ok ( Transformed :: new ( rewritten, true , TreeNodeRecursion :: Jump ) ) ;
420496 }
@@ -441,7 +517,11 @@ pub struct CSE<N, C: CSEController<Node = N>> {
441517 controller : C ,
442518}
443519
444- impl < N : TreeNode + HashNode + Clone + Eq , C : CSEController < Node = N > > CSE < N , C > {
520+ impl < N , C > CSE < N , C >
521+ where
522+ N : TreeNode + HashNode + Clone + NormalizeEq ,
523+ C : CSEController < Node = N > ,
524+ {
445525 pub fn new ( controller : C ) -> Self {
446526 Self {
447527 random_state : RandomState :: new ( ) ,
@@ -557,6 +637,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
557637 ) -> Result < FoundCommonNodes < N > > {
558638 let mut found_common = false ;
559639 let mut node_stats = NodeStats :: new ( ) ;
640+
560641 let id_arrays_list = nodes_list
561642 . iter ( )
562643 . map ( |nodes| {
@@ -596,7 +677,10 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
596677#[ cfg( test) ]
597678mod test {
598679 use crate :: alias:: AliasGenerator ;
599- use crate :: cse:: { CSEController , HashNode , IdArray , Identifier , NodeStats , CSE } ;
680+ use crate :: cse:: {
681+ CSEController , HashNode , IdArray , Identifier , NodeStats , NormalizeEq ,
682+ Normalizeable , CSE ,
683+ } ;
600684 use crate :: tree_node:: tests:: TestTreeNode ;
601685 use crate :: Result ;
602686 use std:: collections:: HashSet ;
@@ -662,6 +746,18 @@ mod test {
662746 }
663747 }
664748
749+ impl Normalizeable for TestTreeNode < String > {
750+ fn can_normalize ( & self ) -> bool {
751+ false // test nodes never opt in, so the visitor keeps their child identifiers in traversal order
752+ }
753+ }
754+
755+ impl NormalizeEq for TestTreeNode < String > {
756+ fn normalize_eq ( & self , other : & Self ) -> bool {
757+ self == other // no normalization for test nodes: falls back to plain `==`
758+ }
759+ }
760+
665761 #[ test]
666762 fn id_array_visitor ( ) -> Result < ( ) > {
667763 let alias_generator = AliasGenerator :: new ( ) ;
0 commit comments