Skip to content

Commit 84bbce6

Browse files
authored
Derive UDWF equality from PartialEq, Hash (#17057)
* Simplify WindowUDFImpl::equals ForeignWindowUDF's impl `ForeignWindowUDF` contains `FFI_WindowUDF` struct which was compared by pointer only. This means that effectively `ForeignWindowUDF` was also compared by pointer only, with all other equality checks being redundant. This commit simplifies the implementation to make it more obvious and more performant. * Add PtrEq wrapper for pointer-based equality * Derive UDWF equality from PartialEq, Hash * Document moved functions
1 parent 8147565 commit 84bbce6

File tree

13 files changed

+194
-263
lines changed

13 files changed

+194
-263
lines changed

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ use datafusion::common::test_util::batches_to_string;
2828
use datafusion::common::{Result, ScalarValue};
2929
use datafusion::prelude::SessionContext;
3030
use datafusion_common::exec_datafusion_err;
31+
use datafusion_expr::ptr_eq::PtrEq;
3132
use datafusion_expr::{
32-
PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
33+
udf_equals_hash, PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF,
34+
WindowUDFImpl,
3335
};
3436
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
3537
use datafusion_functions_window_common::{
@@ -523,10 +525,10 @@ impl OddCounter {
523525
}
524526

525527
fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
526-
#[derive(Debug, Clone)]
528+
#[derive(Debug, Clone, PartialEq, Hash)]
527529
struct SimpleWindowUDF {
528530
signature: Signature,
529-
test_state: Arc<TestState>,
531+
test_state: PtrEq<Arc<TestState>>,
530532
aliases: Vec<String>,
531533
}
532534

@@ -536,7 +538,7 @@ impl OddCounter {
536538
Signature::exact(vec![DataType::Float64], Volatility::Immutable);
537539
Self {
538540
signature,
539-
test_state,
541+
test_state: test_state.into(),
540542
aliases: vec!["odd_counter_alias".to_string()],
541543
}
542544
}
@@ -570,32 +572,7 @@ impl OddCounter {
570572
Ok(Field::new(field_args.name(), DataType::Int64, true).into())
571573
}
572574

573-
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
574-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
575-
return false;
576-
};
577-
let Self {
578-
signature,
579-
test_state,
580-
aliases,
581-
} = self;
582-
signature == &other.signature
583-
&& Arc::ptr_eq(test_state, &other.test_state)
584-
&& aliases == &other.aliases
585-
}
586-
587-
fn hash_value(&self) -> u64 {
588-
let Self {
589-
signature,
590-
test_state,
591-
aliases,
592-
} = self;
593-
let mut hasher = DefaultHasher::new();
594-
signature.hash(&mut hasher);
595-
Arc::as_ptr(test_state).hash(&mut hasher);
596-
aliases.hash(&mut hasher);
597-
hasher.finish()
598-
}
575+
udf_equals_hash!(WindowUDFImpl);
599576
}
600577

601578
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))

datafusion/expr/src/async_udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::utils::{arc_ptr_eq, arc_ptr_hash};
18+
use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash};
1919
use crate::{
2020
udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
2121
};

datafusion/expr/src/expr_fn.rs

Lines changed: 8 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ use crate::function::{
2525
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
2626
StateFieldsArgs,
2727
};
28+
use crate::ptr_eq::PtrEq;
2829
use crate::select_expr::SelectExpr;
29-
use crate::utils::{arc_ptr_eq, arc_ptr_hash};
3030
use crate::{
3131
conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
3232
udf_equals_hash, AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
@@ -403,41 +403,12 @@ pub fn create_udf(
403403

404404
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
405405
/// return type.
406+
#[derive(PartialEq, Hash)]
406407
pub struct SimpleScalarUDF {
407408
name: String,
408409
signature: Signature,
409410
return_type: DataType,
410-
fun: ScalarFunctionImplementation,
411-
}
412-
413-
impl PartialEq for SimpleScalarUDF {
414-
fn eq(&self, other: &Self) -> bool {
415-
let Self {
416-
name,
417-
signature,
418-
return_type,
419-
fun,
420-
} = self;
421-
name == &other.name
422-
&& signature == &other.signature
423-
&& return_type == &other.return_type
424-
&& arc_ptr_eq(fun, &other.fun)
425-
}
426-
}
427-
428-
impl Hash for SimpleScalarUDF {
429-
fn hash<H: Hasher>(&self, state: &mut H) {
430-
let Self {
431-
name,
432-
signature,
433-
return_type,
434-
fun,
435-
} = self;
436-
name.hash(state);
437-
signature.hash(state);
438-
return_type.hash(state);
439-
arc_ptr_hash(fun, state);
440-
}
411+
fun: PtrEq<ScalarFunctionImplementation>,
441412
}
442413

443414
impl Debug for SimpleScalarUDF {
@@ -481,7 +452,7 @@ impl SimpleScalarUDF {
481452
name: name.into(),
482453
signature,
483454
return_type,
484-
fun,
455+
fun: fun.into(),
485456
}
486457
}
487458
}
@@ -690,11 +661,12 @@ pub fn create_udwf(
690661

691662
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
692663
/// return type.
664+
#[derive(PartialEq, Hash)]
693665
pub struct SimpleWindowUDF {
694666
name: String,
695667
signature: Signature,
696668
return_type: DataType,
697-
partition_evaluator_factory: PartitionEvaluatorFactory,
669+
partition_evaluator_factory: PtrEq<PartitionEvaluatorFactory>,
698670
}
699671

700672
impl Debug for SimpleWindowUDF {
@@ -724,7 +696,7 @@ impl SimpleWindowUDF {
724696
name,
725697
signature,
726698
return_type,
727-
partition_evaluator_factory,
699+
partition_evaluator_factory: partition_evaluator_factory.into(),
728700
}
729701
}
730702
}
@@ -757,40 +729,7 @@ impl WindowUDFImpl for SimpleWindowUDF {
757729
)))
758730
}
759731

760-
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
761-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
762-
return false;
763-
};
764-
let Self {
765-
name,
766-
signature,
767-
return_type,
768-
partition_evaluator_factory,
769-
} = self;
770-
name == &other.name
771-
&& signature == &other.signature
772-
&& return_type == &other.return_type
773-
&& Arc::ptr_eq(
774-
partition_evaluator_factory,
775-
&other.partition_evaluator_factory,
776-
)
777-
}
778-
779-
fn hash_value(&self) -> u64 {
780-
let Self {
781-
name,
782-
signature,
783-
return_type,
784-
partition_evaluator_factory,
785-
} = self;
786-
let mut hasher = DefaultHasher::new();
787-
std::any::type_name::<Self>().hash(&mut hasher);
788-
name.hash(&mut hasher);
789-
signature.hash(&mut hasher);
790-
return_type.hash(&mut hasher);
791-
Arc::as_ptr(partition_evaluator_factory).hash(&mut hasher);
792-
hasher.finish()
793-
}
732+
udf_equals_hash!(WindowUDFImpl);
794733
}
795734

796735
pub fn interval_year_month_lit(value: &str) -> Expr {

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pub mod async_udf;
6767
pub mod statistics {
6868
pub use datafusion_expr_common::statistics::*;
6969
}
70+
pub mod ptr_eq;
7071
pub mod test;
7172
pub mod tree_node;
7273
pub mod type_coercion;

datafusion/expr/src/ptr_eq.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::fmt::Debug;
19+
use std::hash::{Hash, Hasher};
20+
use std::ops::Deref;
21+
use std::sync::Arc;
22+
23+
/// Compares two `Arc` pointers for equality based on their underlying pointers values.
24+
/// This is not equivalent to [`Arc::ptr_eq`] for fat pointers, see that method
25+
/// for more information.
26+
pub fn arc_ptr_eq<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool {
27+
std::ptr::eq(Arc::as_ptr(a), Arc::as_ptr(b))
28+
}
29+
30+
/// Hashes an `Arc` pointer based on its underlying pointer value.
31+
/// The general contract for this function is that if [`arc_ptr_eq`] returns `true`
32+
/// for two `Arc`s, then this function should return the same hash value for both.
33+
pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) {
34+
std::ptr::hash(Arc::as_ptr(a), hasher)
35+
}
36+
37+
/// A wrapper around a pointer that implements `PartialEq` and `Hash` comparing
38+
/// the underlying pointer address.
39+
#[derive(Clone)]
40+
#[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse.
41+
pub struct PtrEq<Ptr: PointerType>(Ptr);
42+
43+
impl<T> PartialEq for PtrEq<Arc<T>>
44+
where
45+
T: ?Sized,
46+
{
47+
fn eq(&self, other: &Self) -> bool {
48+
arc_ptr_eq(&self.0, &other.0)
49+
}
50+
}
51+
52+
impl<T> Hash for PtrEq<Arc<T>>
53+
where
54+
T: ?Sized,
55+
{
56+
fn hash<H: Hasher>(&self, state: &mut H) {
57+
arc_ptr_hash(&self.0, state);
58+
}
59+
}
60+
61+
impl<Ptr> From<Ptr> for PtrEq<Ptr>
62+
where
63+
Ptr: PointerType,
64+
{
65+
fn from(ptr: Ptr) -> Self {
66+
PtrEq(ptr)
67+
}
68+
}
69+
70+
impl<T> From<PtrEq<Arc<T>>> for Arc<T>
71+
where
72+
T: ?Sized,
73+
{
74+
fn from(wrapper: PtrEq<Arc<T>>) -> Self {
75+
wrapper.0
76+
}
77+
}
78+
79+
impl<Ptr> Debug for PtrEq<Ptr>
80+
where
81+
Ptr: PointerType + Debug,
82+
{
83+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84+
self.0.fmt(f)
85+
}
86+
}
87+
88+
impl<Ptr> Deref for PtrEq<Ptr>
89+
where
90+
Ptr: PointerType,
91+
{
92+
type Target = Ptr;
93+
94+
fn deref(&self) -> &Self::Target {
95+
&self.0
96+
}
97+
}
98+
99+
trait PointerType {}
100+
impl<T> PointerType for Arc<T> where T: ?Sized {}
101+
102+
#[cfg(test)]
103+
mod tests {
104+
use super::*;
105+
use std::hash::DefaultHasher;
106+
107+
#[test]
108+
pub fn test_ptr_eq_wrapper() {
109+
let a = Arc::new("Hello".to_string());
110+
let b = Arc::new(a.deref().clone());
111+
let c = Arc::new("world".to_string());
112+
113+
let wrapper = PtrEq(Arc::clone(&a));
114+
assert_eq!(wrapper, wrapper);
115+
116+
// same address (equal)
117+
assert_eq!(PtrEq(Arc::clone(&a)), PtrEq(Arc::clone(&a)));
118+
assert_eq!(hash(PtrEq(Arc::clone(&a))), hash(PtrEq(Arc::clone(&a))));
119+
120+
// different address, same content (not equal)
121+
assert_ne!(PtrEq(Arc::clone(&a)), PtrEq(Arc::clone(&b)));
122+
123+
// different address, different content (not equal)
124+
assert_ne!(PtrEq(Arc::clone(&a)), PtrEq(Arc::clone(&c)));
125+
}
126+
127+
fn hash<T: Hash>(value: T) -> u64 {
128+
let hasher = &mut DefaultHasher::new();
129+
value.hash(hasher);
130+
hasher.finish()
131+
}
132+
}

0 commit comments

Comments
 (0)