diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 4f681896dfc66..e6b61e2b8f15f 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1243,7 +1243,6 @@ mod test { // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] - #[expect(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { match std::env::var("RUST_BACKTRACE") { Ok(val) if val == "1" => {} diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 98dd1f235aee7..1342face3998b 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -276,6 +276,7 @@ fn hash_array( /// HAS_NULLS: do we have to check null in the inner loop /// HAS_BUFFERS: if true, array has external buffers; if false, all strings are inlined/ less then 12 bytes /// REHASH: if true, combining with existing hash, otherwise initializing +#[cfg(not(feature = "force_hash_collisions"))] #[inline(never)] fn hash_string_view_array_inner< T: ByteViewType, diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index e4e048ad3c0d8..4694ce077ba52 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -568,6 +568,10 @@ impl PartialEq for ScalarValue { impl PartialOrd for ScalarValue { fn partial_cmp(&self, other: &Self) -> Option { use ScalarValue::*; + + if self.is_null() || other.is_null() { + return None; + } // This purposely doesn't have a catch-all "(_, _)" so that // any newly added enum variant will require editing this list // or else face a compile error @@ -723,8 +727,7 @@ impl PartialOrd for ScalarValue { if k1 == k2 { v1.partial_cmp(v2) } else { None } } (Dictionary(_, _), _) => None, - (Null, Null) => Some(Ordering::Equal), - (Null, _) => None, + _ => None, } } } @@ -4017,10 +4020,14 @@ impl ScalarValue { arr1 == &right } - /// Compare `self` with `other` and return an `Ordering`. + /// Compare two `ScalarValue`s. /// - /// This is the same as [`PartialOrd`] except that it returns - /// `Err` if the values cannot be compared, e.g., they have incompatible data types. + /// Returns an error if: + /// * the values are of incompatible types, or + /// * either value is NULL. + /// + /// This differs from `partial_cmp`, which returns `None` for NULL inputs + /// instead of an error. pub fn try_cmp(&self, other: &Self) -> Result { self.partial_cmp(other).ok_or_else(|| { _internal_datafusion_err!("Uncomparable values: {self:?}, {other:?}") @@ -5760,10 +5767,9 @@ mod tests { .unwrap(), Ordering::Less ); - assert_eq!( + assert!( ScalarValue::try_cmp(&ScalarValue::Int32(None), &ScalarValue::Int32(Some(2))) - .unwrap(), - Ordering::Less + .is_err() ); assert_starts_with( ScalarValue::try_cmp( @@ -9348,4 +9354,20 @@ mod tests { ] ); } + #[test] + fn scalar_partial_ordering_nulls() { + use ScalarValue::*; + + assert_eq!(Int32(Some(3)).partial_cmp(&Int32(None)), None); + + assert_eq!(Int32(None).partial_cmp(&Int32(Some(3))), None); + + assert_eq!(Int32(None).partial_cmp(&Int32(None)), None); + + assert_eq!(Null.partial_cmp(&Int32(Some(3))), None); + + assert_eq!(Int32(Some(3)).partial_cmp(&Null), None); + + assert_eq!(Null.partial_cmp(&Null), None); + } } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 03310a7bde193..51de9a76cfd0a 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -1021,27 +1021,33 @@ mod tests { fn vector_ord() { assert!(vec![1, 0, 0, 0, 0, 0, 0, 1] < vec![1, 0, 0, 0, 0, 0, 0, 2]); assert!(vec![1, 0, 0, 0, 0, 0, 1, 1] > vec![1, 0, 0, 0, 0, 0, 0, 2]); - assert!( + // Vectors containing Null values cannot be compared because + // ScalarValue::partial_cmp returns None for null comparisons + assert_eq!( vec![ ScalarValue::Int32(Some(2)), Null, ScalarValue::Int32(Some(0)), - ] < vec![ + ] + .partial_cmp(&vec![ ScalarValue::Int32(Some(2)), Null, ScalarValue::Int32(Some(1)), - ] + ]), + None ); - assert!( + assert_eq!( vec![ ScalarValue::Int32(Some(2)), ScalarValue::Int32(None), ScalarValue::Int32(Some(0)), - ] < vec![ + ] + .partial_cmp(&vec![ ScalarValue::Int32(Some(2)), ScalarValue::Int32(None), ScalarValue::Int32(Some(1)), - ] + ]), + None ); } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index ec22be525464b..38e256eb1739a 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -24,7 +24,7 @@ use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{Column, Result}; +use datafusion_common::{Column, Result, TableReference}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output /// For example, `max(x)` is written to `col("max(x)")` @@ -107,7 +107,7 @@ fn rewrite_in_terms_of_projection( for proj_expr in proj_exprs { proj_expr.apply(|e| { if expr_match(&search_col, e) { - found = Some(e.clone()); + found = Some(proj_expr.clone()); return Ok(TreeNodeRecursion::Stop); } Ok(TreeNodeRecursion::Continue) @@ -115,16 +115,56 @@ fn rewrite_in_terms_of_projection( } if let Some(found) = found { + // Determine what to return based on the original expression type + let result_expr = if let Expr::Column(original_col) = &expr { + // For plain columns, preserve the original qualification status + Expr::Column(Column::new( + original_col.relation.clone(), + search_col.try_as_col().unwrap().name.clone(), + )) + } else { + // For other expressions (aggregates, etc.), return a column reference + // to the projection output, unless it's wrapped in a cast + match &normalized_expr { + Expr::Cast(_) | Expr::TryCast(_) => { + // For casts, use the projection expression to preserve aliases + found + } + _ => { + // For aggregates and other expressions, create a column reference + // Split the column name at the last dot to handle legacy qualified names + let col_name = search_col.try_as_col().unwrap().name.as_str(); + let col_ref = if let Some((relation, field_name)) = + col_name.rsplit_once('.') + { + Expr::Column(Column::new( + Some(TableReference::bare(relation)), + field_name, + )) + } else { + search_col + }; + + // If the projection expression has an alias, preserve it + if let Expr::Alias(Alias { name, .. }) = &found { + col_ref.alias(name.clone()) + } else { + col_ref + } + } + } + }; + return Ok(Transformed::yes(match normalized_expr { Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { - expr: Box::new(found), + expr: Box::new(result_expr), data_type, }), Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { - expr: Box::new(found), + expr: Box::new(result_expr), data_type, }), - _ => found, + _ => result_expr, })); } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4219c24bfc9c9..e8ace62330642 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4761,7 +4761,7 @@ mod tests { let f2 = count_window_function(schema_without_metadata()); assert_eq!(f, f2); assert_eq!(hash(&f), hash(&f2)); - assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + assert_eq!(f.partial_cmp(&f2), None); // Same like `f`, except for schema metadata let o = count_window_function(schema_with_metadata());