Skip to content

Commit b78c517

Browse files
authored
Update ScalarUDF equality
compare ScalarUDFs by pointer equality update projection pushdown tests to share UDF instances add unit test covering new equality semantics
1 parent f302383 commit b78c517

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,7 @@ mod tests {
13891389

13901390
#[test]
13911391
fn test_update_matching_exprs() -> Result<()> {
1392+
let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new()));
13921393
let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
13931394
Arc::new(BinaryExpr::new(
13941395
Arc::new(Column::new("a", 3)),
@@ -1403,7 +1404,7 @@ mod tests {
14031404
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
14041405
Arc::new(ScalarFunctionExpr::new(
14051406
"scalar_expr",
1406-
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
1407+
Arc::clone(&udf),
14071408
vec![
14081409
Arc::new(BinaryExpr::new(
14091410
Arc::new(Column::new("b", 1)),
@@ -1468,7 +1469,7 @@ mod tests {
14681469
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
14691470
Arc::new(ScalarFunctionExpr::new(
14701471
"scalar_expr",
1471-
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
1472+
Arc::clone(&udf),
14721473
vec![
14731474
Arc::new(BinaryExpr::new(
14741475
Arc::new(Column::new("b", 1)),
@@ -1522,6 +1523,7 @@ mod tests {
15221523

15231524
#[test]
15241525
fn test_update_projected_exprs() -> Result<()> {
1526+
let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new()));
15251527
let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
15261528
Arc::new(BinaryExpr::new(
15271529
Arc::new(Column::new("a", 3)),
@@ -1536,7 +1538,7 @@ mod tests {
15361538
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
15371539
Arc::new(ScalarFunctionExpr::new(
15381540
"scalar_expr",
1539-
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
1541+
Arc::clone(&udf),
15401542
vec![
15411543
Arc::new(BinaryExpr::new(
15421544
Arc::new(Column::new("b", 1)),
@@ -1601,7 +1603,7 @@ mod tests {
16011603
Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
16021604
Arc::new(ScalarFunctionExpr::new(
16031605
"scalar_expr",
1604-
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
1606+
Arc::clone(&udf),
16051607
vec![
16061608
Arc::new(BinaryExpr::new(
16071609
Arc::new(Column::new("b_new", 1)),

datafusion/expr/src/expr.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,51 @@ mod test {
25032503
assert_eq!(udf.signature().volatility, Volatility::Volatile);
25042504
}
25052505

2506+
#[test]
2507+
fn test_scalar_udf_eq_pointer() {
2508+
#[derive(Debug)]
2509+
struct DummyUDF {
2510+
signature: Signature,
2511+
}
2512+
2513+
impl DummyUDF {
2514+
fn new() -> Self {
2515+
Self {
2516+
signature: Signature::variadic_any(Volatility::Immutable),
2517+
}
2518+
}
2519+
}
2520+
2521+
impl ScalarUDFImpl for DummyUDF {
2522+
fn as_any(&self) -> &dyn Any {
2523+
self
2524+
}
2525+
2526+
fn name(&self) -> &str {
2527+
"dummy_udf"
2528+
}
2529+
2530+
fn signature(&self) -> &Signature {
2531+
&self.signature
2532+
}
2533+
2534+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
2535+
Ok(DataType::Int32)
2536+
}
2537+
2538+
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
2539+
unimplemented!("DummyUDF::invoke")
2540+
}
2541+
}
2542+
2543+
let udf1 = ScalarUDF::new_from_impl(DummyUDF::new());
2544+
let udf1_clone = udf1.clone();
2545+
let udf2 = ScalarUDF::new_from_impl(DummyUDF::new());
2546+
2547+
assert!(udf1.eq(&udf1_clone));
2548+
assert!(!udf1.eq(&udf2));
2549+
}
2550+
25062551
use super::*;
25072552

25082553
#[test]

datafusion/expr/src/udf.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ pub struct ScalarUDF {
6161

6262
impl PartialEq for ScalarUDF {
6363
fn eq(&self, other: &Self) -> bool {
64-
self.inner.equals(other.inner.as_ref())
64+
Arc::ptr_eq(&self.inner, &other.inner)
6565
}
6666
}
6767

@@ -678,9 +678,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
678678
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
679679
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
680680
///
681-
/// By default, compares [`Self::name`] and [`Self::signature`].
681+
/// By default, checks for pointer equality.
682682
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
683-
self.name() == other.name() && self.signature() == other.signature()
683+
let self_ptr = self as *const _ as *const ();
684+
let other_ptr = other as *const _ as *const ();
685+
std::ptr::eq(self_ptr, other_ptr)
684686
}
685687

686688
/// Returns a hash value for this scalar UDF.

0 commit comments

Comments
 (0)