@@ -626,6 +626,59 @@ impl EquivalenceGroup {
626626 JoinType :: RightSemi | JoinType :: RightAnti => right_equivalences. clone ( ) ,
627627 }
628628 }
629+
630+ /// Checks if two expressions are equal either directly or through equivalence classes.
631+ /// For complex expressions (e.g. a + b), checks that the expression trees are structurally
632+ /// identical and their leaf nodes are equivalent either directly or through equivalence classes.
633+ pub fn exprs_equal (
634+ & self ,
635+ left : & Arc < dyn PhysicalExpr > ,
636+ right : & Arc < dyn PhysicalExpr > ,
637+ ) -> bool {
638+ // Direct equality check
639+ if left. eq ( right) {
640+ return true ;
641+ }
642+
643+ // Check if expressions are equivalent through equivalence classes
644+ // We need to check both directions since expressions might be in different classes
645+ if let Some ( left_class) = self . get_equivalence_class ( left) {
646+ if left_class. contains ( right) {
647+ return true ;
648+ }
649+ }
650+ if let Some ( right_class) = self . get_equivalence_class ( right) {
651+ if right_class. contains ( left) {
652+ return true ;
653+ }
654+ }
655+
656+ // For non-leaf nodes, check structural equality
657+ let left_children = left. children ( ) ;
658+ let right_children = right. children ( ) ;
659+
660+ // If either expression is a leaf node and we haven't found equality yet,
661+ // they must be different
662+ if left_children. is_empty ( ) || right_children. is_empty ( ) {
663+ return false ;
664+ }
665+
666+ // Type equality check through reflection
667+ if left. as_any ( ) . type_id ( ) != right. as_any ( ) . type_id ( ) {
668+ return false ;
669+ }
670+
671+ // Check if the number of children is the same
672+ if left_children. len ( ) != right_children. len ( ) {
673+ return false ;
674+ }
675+
676+ // Check if all children are equal
677+ left_children
678+ . into_iter ( )
679+ . zip ( right_children)
680+ . all ( |( left_child, right_child) | self . exprs_equal ( left_child, right_child) )
681+ }
629682}
630683
631684impl Display for EquivalenceGroup {
@@ -647,9 +700,10 @@ mod tests {
647700
648701 use super :: * ;
649702 use crate :: equivalence:: tests:: create_test_params;
650- use crate :: expressions:: { lit, Literal } ;
703+ use crate :: expressions:: { lit, BinaryExpr , Literal } ;
651704
652705 use datafusion_common:: { Result , ScalarValue } ;
706+ use datafusion_expr:: Operator ;
653707
654708 #[ test]
655709 fn test_bridge_groups ( ) -> Result < ( ) > {
@@ -777,4 +831,159 @@ mod tests {
777831 assert ! ( !cls1. contains_any( & cls3) ) ;
778832 assert ! ( !cls2. contains_any( & cls3) ) ;
779833 }
834+
835+ #[ test]
836+ fn test_exprs_equal ( ) -> Result < ( ) > {
837+ struct TestCase {
838+ left : Arc < dyn PhysicalExpr > ,
839+ right : Arc < dyn PhysicalExpr > ,
840+ expected : bool ,
841+ description : & ' static str ,
842+ }
843+
844+ // Create test columns
845+ let col_a = Arc :: new ( Column :: new ( "a" , 0 ) ) as Arc < dyn PhysicalExpr > ;
846+ let col_b = Arc :: new ( Column :: new ( "b" , 1 ) ) as Arc < dyn PhysicalExpr > ;
847+ let col_x = Arc :: new ( Column :: new ( "x" , 2 ) ) as Arc < dyn PhysicalExpr > ;
848+ let col_y = Arc :: new ( Column :: new ( "y" , 3 ) ) as Arc < dyn PhysicalExpr > ;
849+
850+ // Create test literals
851+ let lit_1 =
852+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 1 ) ) ) ) as Arc < dyn PhysicalExpr > ;
853+ let lit_2 =
854+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 2 ) ) ) ) as Arc < dyn PhysicalExpr > ;
855+
856+ // Create equivalence group with classes (a = x) and (b = y)
857+ let eq_group = EquivalenceGroup :: new ( vec ! [
858+ EquivalenceClass :: new( vec![ Arc :: clone( & col_a) , Arc :: clone( & col_x) ] ) ,
859+ EquivalenceClass :: new( vec![ Arc :: clone( & col_b) , Arc :: clone( & col_y) ] ) ,
860+ ] ) ;
861+
862+ let test_cases = vec ! [
863+ // Basic equality tests
864+ TestCase {
865+ left: Arc :: clone( & col_a) ,
866+ right: Arc :: clone( & col_a) ,
867+ expected: true ,
868+ description: "Same column should be equal" ,
869+ } ,
870+ // Equivalence class tests
871+ TestCase {
872+ left: Arc :: clone( & col_a) ,
873+ right: Arc :: clone( & col_x) ,
874+ expected: true ,
875+ description: "Columns in same equivalence class should be equal" ,
876+ } ,
877+ TestCase {
878+ left: Arc :: clone( & col_b) ,
879+ right: Arc :: clone( & col_y) ,
880+ expected: true ,
881+ description: "Columns in same equivalence class should be equal" ,
882+ } ,
883+ TestCase {
884+ left: Arc :: clone( & col_a) ,
885+ right: Arc :: clone( & col_b) ,
886+ expected: false ,
887+ description:
888+ "Columns in different equivalence classes should not be equal" ,
889+ } ,
890+ // Literal tests
891+ TestCase {
892+ left: Arc :: clone( & lit_1) ,
893+ right: Arc :: clone( & lit_1) ,
894+ expected: true ,
895+ description: "Same literal should be equal" ,
896+ } ,
897+ TestCase {
898+ left: Arc :: clone( & lit_1) ,
899+ right: Arc :: clone( & lit_2) ,
900+ expected: false ,
901+ description: "Different literals should not be equal" ,
902+ } ,
903+ // Complex expression tests
904+ TestCase {
905+ left: Arc :: new( BinaryExpr :: new(
906+ Arc :: clone( & col_a) ,
907+ Operator :: Plus ,
908+ Arc :: clone( & col_b) ,
909+ ) ) as Arc <dyn PhysicalExpr >,
910+ right: Arc :: new( BinaryExpr :: new(
911+ Arc :: clone( & col_x) ,
912+ Operator :: Plus ,
913+ Arc :: clone( & col_y) ,
914+ ) ) as Arc <dyn PhysicalExpr >,
915+ expected: true ,
916+ description:
917+ "Binary expressions with equivalent operands should be equal" ,
918+ } ,
919+ TestCase {
920+ left: Arc :: new( BinaryExpr :: new(
921+ Arc :: clone( & col_a) ,
922+ Operator :: Plus ,
923+ Arc :: clone( & col_b) ,
924+ ) ) as Arc <dyn PhysicalExpr >,
925+ right: Arc :: new( BinaryExpr :: new(
926+ Arc :: clone( & col_x) ,
927+ Operator :: Plus ,
928+ Arc :: clone( & col_a) ,
929+ ) ) as Arc <dyn PhysicalExpr >,
930+ expected: false ,
931+ description:
932+ "Binary expressions with non-equivalent operands should not be equal" ,
933+ } ,
934+ TestCase {
935+ left: Arc :: new( BinaryExpr :: new(
936+ Arc :: clone( & col_a) ,
937+ Operator :: Plus ,
938+ Arc :: clone( & lit_1) ,
939+ ) ) as Arc <dyn PhysicalExpr >,
940+ right: Arc :: new( BinaryExpr :: new(
941+ Arc :: clone( & col_x) ,
942+ Operator :: Plus ,
943+ Arc :: clone( & lit_1) ,
944+ ) ) as Arc <dyn PhysicalExpr >,
945+ expected: true ,
946+ description: "Binary expressions with equivalent column and same literal should be equal" ,
947+ } ,
948+ TestCase {
949+ left: Arc :: new( BinaryExpr :: new(
950+ Arc :: new( BinaryExpr :: new(
951+ Arc :: clone( & col_a) ,
952+ Operator :: Plus ,
953+ Arc :: clone( & col_b) ,
954+ ) ) ,
955+ Operator :: Multiply ,
956+ Arc :: clone( & lit_1) ,
957+ ) ) as Arc <dyn PhysicalExpr >,
958+ right: Arc :: new( BinaryExpr :: new(
959+ Arc :: new( BinaryExpr :: new(
960+ Arc :: clone( & col_x) ,
961+ Operator :: Plus ,
962+ Arc :: clone( & col_y) ,
963+ ) ) ,
964+ Operator :: Multiply ,
965+ Arc :: clone( & lit_1) ,
966+ ) ) as Arc <dyn PhysicalExpr >,
967+ expected: true ,
968+ description: "Nested binary expressions with equivalent operands should be equal" ,
969+ } ,
970+ ] ;
971+
972+ for TestCase {
973+ left,
974+ right,
975+ expected,
976+ description,
977+ } in test_cases
978+ {
979+ let actual = eq_group. exprs_equal ( & left, & right) ;
980+ assert_eq ! (
981+ actual, expected,
982+ "{}: Failed comparing {:?} and {:?}, expected {}, got {}" ,
983+ description, left, right, expected, actual
984+ ) ;
985+ }
986+
987+ Ok ( ( ) )
988+ }
780989}
0 commit comments