Skip to content

Commit 74480ac

Browse files
zhuliquan and alamb authored
feat: support normalized expr in CSE (#13315)
* feat: support normalized expr in CSE * feat: support normalize_eq in cse optimization * feat: support cumulative binary expr result in normalize_eq --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 87b77bb commit 74480ac

File tree

4 files changed

+790
-32
lines changed

4 files changed

+790
-32
lines changed

datafusion/common/src/cse.rs

Lines changed: 123 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,33 +50,63 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
5050
}
5151
}
5252

53+
/// The `Normalizeable` trait defines a method to determine whether a node can be normalized.
54+
///
55+
/// Normalization is the process of converting a node into a canonical form that can be used
56+
/// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE),
57+
/// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal.
58+
pub trait Normalizeable {
59+
fn can_normalize(&self) -> bool;
60+
}
61+
62+
/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
63+
/// normalized nodes in optimizations like Common Subexpression Elimination (CSE).
64+
///
65+
/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization)
66+
/// are considered equal in CSE optimization, even if their original forms differ.
67+
///
68+
/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their
69+
/// internal representations.
70+
pub trait NormalizeEq: Eq + Normalizeable {
71+
fn normalize_eq(&self, other: &Self) -> bool;
72+
}
73+
5374
/// Identifier that represents a [`TreeNode`] tree.
5475
///
5576
/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
5677
/// "have no collision (as low as possible)"
57-
#[derive(Debug, Eq, PartialEq)]
58-
struct Identifier<'n, N> {
78+
#[derive(Debug, Eq)]
79+
struct Identifier<'n, N: NormalizeEq> {
5980
// Hash of `node` built up incrementally during the first, visiting traversal.
6081
// Its value is not necessarily equal to default hash of the node. E.g. it is not
6182
// equal to `expr.hash()` if the node is `Expr`.
6283
hash: u64,
6384
node: &'n N,
6485
}
6586

66-
impl<N> Clone for Identifier<'_, N> {
87+
impl<N: NormalizeEq> Clone for Identifier<'_, N> {
6788
fn clone(&self) -> Self {
6889
*self
6990
}
7091
}
71-
impl<N> Copy for Identifier<'_, N> {}
92+
impl<N: NormalizeEq> Copy for Identifier<'_, N> {}
7293

73-
impl<N> Hash for Identifier<'_, N> {
94+
impl<N: NormalizeEq> Hash for Identifier<'_, N> {
7495
fn hash<H: Hasher>(&self, state: &mut H) {
7596
state.write_u64(self.hash);
7697
}
7798
}
7899

79-
impl<'n, N: HashNode> Identifier<'n, N> {
100+
impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
101+
fn eq(&self, other: &Self) -> bool {
102+
self.hash == other.hash && self.node.normalize_eq(other.node)
103+
}
104+
}
105+
106+
impl<'n, N> Identifier<'n, N>
107+
where
108+
N: HashNode + NormalizeEq,
109+
{
80110
fn new(node: &'n N, random_state: &RandomState) -> Self {
81111
let mut hasher = random_state.build_hasher();
82112
node.hash_node(&mut hasher);
@@ -213,7 +243,11 @@ pub enum FoundCommonNodes<N> {
213243
///
214244
/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
215245
/// because they should not be recognized as common subtree.
216-
struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
246+
struct CSEVisitor<'a, 'n, N, C>
247+
where
248+
N: NormalizeEq,
249+
C: CSEController<Node = N>,
250+
{
217251
/// statistics of [`TreeNode`]s
218252
node_stats: &'a mut NodeStats<'n, N>,
219253

@@ -244,7 +278,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
244278
}
245279

246280
/// Record item that is used when traversing a [`TreeNode`] tree.
247-
enum VisitRecord<'n, N> {
281+
enum VisitRecord<'n, N>
282+
where
283+
N: NormalizeEq,
284+
{
248285
/// Marks the beginning of [`TreeNode`]. It contains:
249286
/// - The post-order index assigned during the first, visiting traversal.
250287
EnterMark(usize),
@@ -258,7 +295,11 @@ enum VisitRecord<'n, N> {
258295
NodeItem(Identifier<'n, N>, bool),
259296
}
260297

261-
impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
298+
impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
299+
where
300+
N: TreeNode + HashNode + NormalizeEq,
301+
C: CSEController<Node = N>,
302+
{
262303
/// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
263304
/// it. Returns a tuple that contains:
264305
/// - The pre-order index of the [`TreeNode`] we marked.
@@ -271,17 +312,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
271312
/// information up from children to parents via `visit_stack` during the first,
272313
/// visiting traversal and no need to test the expression's validity beforehand with
273314
/// an extra traversal).
274-
fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
275-
let mut node_id = None;
315+
fn pop_enter_mark(
316+
&mut self,
317+
can_normalize: bool,
318+
) -> (usize, Option<Identifier<'n, N>>, bool) {
319+
let mut node_ids: Vec<Identifier<'n, N>> = vec![];
276320
let mut is_valid = true;
277321

278322
while let Some(item) = self.visit_stack.pop() {
279323
match item {
280324
VisitRecord::EnterMark(down_index) => {
325+
if can_normalize {
326+
node_ids.sort_by_key(|i| i.hash);
327+
}
328+
let node_id = node_ids
329+
.into_iter()
330+
.fold(None, |accum, item| Some(item.combine(accum)));
281331
return (down_index, node_id, is_valid);
282332
}
283333
VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
284-
node_id = Some(sub_node_id.combine(node_id));
334+
node_ids.push(sub_node_id);
285335
is_valid &= sub_node_is_valid;
286336
}
287337
}
@@ -290,8 +340,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
290340
}
291341
}
292342

293-
impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
294-
for CSEVisitor<'_, 'n, N, C>
343+
impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
344+
where
345+
N: TreeNode + HashNode + NormalizeEq,
346+
C: CSEController<Node = N>,
295347
{
296348
type Node = N;
297349

@@ -331,7 +383,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
331383
}
332384

333385
fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
334-
let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark();
386+
let (down_index, sub_node_id, sub_node_is_valid) =
387+
self.pop_enter_mark(node.can_normalize());
335388

336389
let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
337390
let is_valid = C::is_valid(node) && sub_node_is_valid;
@@ -369,7 +422,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
369422
/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
370423
/// corresponding temporary [`TreeNode`], that column contains the evaluate result of
371424
/// replaced [`TreeNode`] tree.
372-
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
425+
struct CSERewriter<'a, 'n, N, C>
426+
where
427+
N: NormalizeEq,
428+
C: CSEController<Node = N>,
429+
{
373430
/// statistics of [`TreeNode`]s
374431
node_stats: &'a NodeStats<'n, N>,
375432

@@ -386,8 +443,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
386443
controller: &'a mut C,
387444
}
388445

389-
impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
390-
for CSERewriter<'_, '_, N, C>
446+
impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
447+
where
448+
N: TreeNode + NormalizeEq,
449+
C: CSEController<Node = N>,
391450
{
392451
type Node = N;
393452

@@ -408,13 +467,30 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
408467
self.down_index += 1;
409468
}
410469

411-
let (node, alias) =
412-
self.common_nodes.entry(node_id).or_insert_with(|| {
413-
let node_alias = self.controller.generate_alias();
414-
(node, node_alias)
415-
});
416-
417-
let rewritten = self.controller.rewrite(node, alias);
470+
// We *must* replace all original nodes with same `node_id`, not just the first
471+
// node which is inserted into the common_nodes. This is because nodes with the same
472+
// `node_id` are semantically equivalent, but not exactly the same.
473+
//
474+
// For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
475+
// In this case, we should replace the common expression `1 + a` with a new variable
476+
// (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
477+
// `__common_cse_1`.
478+
//
479+
// The final result would be:
480+
// - `__common_cse_1 as a + 1`
481+
// - `__common_cse_1 as 1 + a`
482+
//
483+
// This way, we can efficiently handle semantically equivalent expressions without
484+
// incorrectly treating them as identical.
485+
let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
486+
{
487+
self.controller.rewrite(&node, alias)
488+
} else {
489+
let node_alias = self.controller.generate_alias();
490+
let rewritten = self.controller.rewrite(&node, &node_alias);
491+
self.common_nodes.insert(node_id, (node, node_alias));
492+
rewritten
493+
};
418494

419495
return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
420496
}
@@ -441,7 +517,11 @@ pub struct CSE<N, C: CSEController<Node = N>> {
441517
controller: C,
442518
}
443519

444-
impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C> {
520+
impl<N, C> CSE<N, C>
521+
where
522+
N: TreeNode + HashNode + Clone + NormalizeEq,
523+
C: CSEController<Node = N>,
524+
{
445525
pub fn new(controller: C) -> Self {
446526
Self {
447527
random_state: RandomState::new(),
@@ -557,6 +637,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
557637
) -> Result<FoundCommonNodes<N>> {
558638
let mut found_common = false;
559639
let mut node_stats = NodeStats::new();
640+
560641
let id_arrays_list = nodes_list
561642
.iter()
562643
.map(|nodes| {
@@ -596,7 +677,10 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
596677
#[cfg(test)]
597678
mod test {
598679
use crate::alias::AliasGenerator;
599-
use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE};
680+
use crate::cse::{
681+
CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
682+
Normalizeable, CSE,
683+
};
600684
use crate::tree_node::tests::TestTreeNode;
601685
use crate::Result;
602686
use std::collections::HashSet;
@@ -662,6 +746,18 @@ mod test {
662746
}
663747
}
664748

749+
impl Normalizeable for TestTreeNode<String> {
750+
fn can_normalize(&self) -> bool {
751+
false
752+
}
753+
}
754+
755+
impl NormalizeEq for TestTreeNode<String> {
756+
fn normalize_eq(&self, other: &Self) -> bool {
757+
self == other
758+
}
759+
}
760+
665761
#[test]
666762
fn id_array_visitor() -> Result<()> {
667763
let alias_generator = AliasGenerator::new();

0 commit comments

Comments
 (0)