Skip to content

Commit 8494a39

Browse files
findepi and alamb authored
Derive WindowUDFImpl equality, hash from Eq, Hash traits (#17081)
* Move `DynEq`, `DynHash` to `expr-common`

  Move the `DynEq` and `DynHash` traits from the physical expressions crate to a common crate for physical and logical expressions. This allows them to be used by logical expressions.

* Derive WindowUDFImpl equality, hash from Eq, Hash traits

  Previously, the `WindowUDFImpl` trait contained `equals` and `hash_value` methods with contracts following the `Eq` and `Hash` traits. However, the existence of default implementations of these methods made it error-prone, with many functions (scalar, aggregate, window) failing to customize equality even though they ought to. There is no fix for this that is not an API-breaking change, so a breaking change is warranted. Removing the default implementations would be enough of a solution, but at the cost of a lot of boilerplate needed in implementations. Instead, this removes the methods from the trait, and reuses the `DynEq`, `DynHash` traits used previously only for physical expressions. This allows functions to provide their implementations using no more than `#[derive(PartialEq, Eq, Hash)]` in a typical case.

* upgrade guide
* fix typo
* Fix PartialEq for WindowUDF impl: the wrong `Any` was compared
* Seal DynEq, DynHash
* Link to epic in upgrade guide

  Co-authored-by: Andrew Lamb <[email protected]>

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 28e042d commit 8494a39

File tree

20 files changed

+160
-142
lines changed

20 files changed

+160
-142
lines changed

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use datafusion::prelude::*;
4343
/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance.
4444
///
4545
/// To do so, we must implement the `WindowUDFImpl` trait.
46-
#[derive(Debug, Clone)]
46+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
4747
struct SmoothItUdf {
4848
signature: Signature,
4949
}
@@ -149,7 +149,7 @@ impl PartitionEvaluator for MyPartitionEvaluator {
149149
}
150150

151151
/// This UDWF will show how to use the WindowUDFImpl::simplify() API
152-
#[derive(Debug, Clone)]
152+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
153153
struct SimplifySmoothItUdf {
154154
signature: Signature,
155155
}

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ use datafusion::prelude::SessionContext;
3030
use datafusion_common::exec_datafusion_err;
3131
use datafusion_expr::ptr_eq::PtrEq;
3232
use datafusion_expr::{
33-
udf_equals_hash, PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF,
34-
WindowUDFImpl,
33+
PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
3534
};
3635
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
3736
use datafusion_functions_window_common::{
@@ -42,7 +41,7 @@ use datafusion_physical_expr::{
4241
PhysicalExpr,
4342
};
4443
use std::collections::HashMap;
45-
use std::hash::{DefaultHasher, Hash, Hasher};
44+
use std::hash::{Hash, Hasher};
4645
use std::{
4746
any::Any,
4847
ops::Range,
@@ -571,8 +570,6 @@ impl OddCounter {
571570
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
572571
Ok(Field::new(field_args.name(), DataType::Int64, true).into())
573572
}
574-
575-
udf_equals_hash!(WindowUDFImpl);
576573
}
577574

578575
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
@@ -648,7 +645,7 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef {
648645
Arc::new(array)
649646
}
650647

651-
#[derive(Debug)]
648+
#[derive(Debug, PartialEq, Eq, Hash)]
652649
struct VariadicWindowUDF {
653650
signature: Signature,
654651
}
@@ -770,6 +767,31 @@ struct MetadataBasedWindowUdf {
770767
metadata: HashMap<String, String>,
771768
}
772769

770+
impl PartialEq for MetadataBasedWindowUdf {
771+
fn eq(&self, other: &Self) -> bool {
772+
let Self {
773+
name,
774+
signature,
775+
metadata,
776+
} = self;
777+
name == &other.name
778+
&& signature == &other.signature
779+
&& metadata == &other.metadata
780+
}
781+
}
782+
impl Eq for MetadataBasedWindowUdf {}
783+
impl Hash for MetadataBasedWindowUdf {
784+
fn hash<H: Hasher>(&self, state: &mut H) {
785+
let Self {
786+
name,
787+
signature,
788+
metadata: _, // unhashable
789+
} = self;
790+
name.hash(state);
791+
signature.hash(state);
792+
}
793+
}
794+
773795
impl MetadataBasedWindowUdf {
774796
fn new(metadata: HashMap<String, String>) -> Self {
775797
// The name we return must be unique. Otherwise we will not call distinct
@@ -820,33 +842,6 @@ impl WindowUDFImpl for MetadataBasedWindowUdf {
820842
.with_metadata(self.metadata.clone())
821843
.into())
822844
}
823-
824-
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
825-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
826-
return false;
827-
};
828-
let Self {
829-
name,
830-
signature,
831-
metadata,
832-
} = self;
833-
name == &other.name
834-
&& signature == &other.signature
835-
&& metadata == &other.metadata
836-
}
837-
838-
fn hash_value(&self) -> u64 {
839-
let Self {
840-
name,
841-
signature,
842-
metadata: _, // unhashable
843-
} = self;
844-
let mut hasher = DefaultHasher::new();
845-
std::any::type_name::<Self>().hash(&mut hasher);
846-
name.hash(&mut hasher);
847-
signature.hash(&mut hasher);
848-
hasher.finish()
849-
}
850845
}
851846

852847
#[derive(Debug)]

datafusion/expr-common/src/dyn_eq.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::hash::{Hash, Hasher};
20+
21+
/// A dyn-compatible version of [`Eq`] trait.
22+
/// The implementation constraints for this trait are the same as for [`Eq`]:
23+
/// the implementation must be reflexive, symmetric, and transitive.
24+
/// Additionally, if two values can be compared with [`DynEq`] and [`PartialEq`] then
25+
/// they must be [`DynEq`]-equal if and only if they are [`PartialEq`]-equal.
26+
/// It is therefore strongly discouraged to implement this trait for types
27+
/// that implement `PartialEq<Other>` for any type `Other` other than `Self`.
28+
///
29+
/// Note: This trait should not be implemented directly. Implement `Eq` and `Any` and use
30+
/// the blanket implementation.
31+
#[allow(private_bounds)]
32+
pub trait DynEq: private::EqSealed {
33+
fn dyn_eq(&self, other: &dyn Any) -> bool;
34+
}
35+
36+
impl<T: Eq + Any> private::EqSealed for T {}
37+
impl<T: Eq + Any> DynEq for T {
38+
fn dyn_eq(&self, other: &dyn Any) -> bool {
39+
other.downcast_ref::<Self>() == Some(self)
40+
}
41+
}
42+
43+
/// A dyn-compatible version of [`Hash`] trait.
44+
/// If two values are equal according to [`DynEq`], they must produce the same hash value.
45+
///
46+
/// Note: This trait should not be implemented directly. Implement `Hash` and `Any` and use
47+
/// the blanket implementation.
48+
#[allow(private_bounds)]
49+
pub trait DynHash: private::HashSealed {
50+
fn dyn_hash(&self, _state: &mut dyn Hasher);
51+
}
52+
53+
impl<T: Hash + Any> private::HashSealed for T {}
54+
impl<T: Hash + Any> DynHash for T {
55+
fn dyn_hash(&self, mut state: &mut dyn Hasher) {
56+
self.type_id().hash(&mut state);
57+
self.hash(&mut state)
58+
}
59+
}
60+
61+
mod private {
62+
pub(super) trait EqSealed {}
63+
pub(super) trait HashSealed {}
64+
}

datafusion/expr-common/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
pub mod accumulator;
3636
pub mod casts;
3737
pub mod columnar_value;
38+
pub mod dyn_eq;
3839
pub mod groups_accumulator;
3940
pub mod interval_arithmetic;
4041
pub mod operator;

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,6 @@ impl WindowUDFImpl for SimpleWindowUDF {
695695
true,
696696
)))
697697
}
698-
699-
udf_equals_hash!(WindowUDFImpl);
700698
}
701699

702700
pub fn interval_year_month_lit(value: &str) -> Expr {

datafusion/expr/src/udf_eq.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl};
1919
use std::fmt::Debug;
20-
use std::hash::{Hash, Hasher};
20+
use std::hash::{DefaultHasher, Hash, Hasher};
2121
use std::ops::Deref;
2222
use std::sync::Arc;
2323

@@ -97,7 +97,18 @@ macro_rules! impl_for_udf_eq {
9797

9898
impl_for_udf_eq!(dyn AggregateUDFImpl + '_);
9999
impl_for_udf_eq!(dyn ScalarUDFImpl + '_);
100-
impl_for_udf_eq!(dyn WindowUDFImpl + '_);
100+
101+
impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
102+
fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
103+
self.as_ref().dyn_eq(other.as_any())
104+
}
105+
106+
fn hash_value(&self) -> u64 {
107+
let hasher = &mut DefaultHasher::new();
108+
self.as_ref().dyn_hash(hasher);
109+
hasher.finish()
110+
}
111+
}
101112

102113
#[cfg(test)]
103114
mod tests {

datafusion/expr/src/udwf.rs

Lines changed: 23 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use arrow::compute::SortOptions;
2121
use std::cmp::Ordering;
22-
use std::hash::{DefaultHasher, Hash, Hasher};
22+
use std::hash::{Hash, Hasher};
2323
use std::{
2424
any::Any,
2525
fmt::{self, Debug, Display, Formatter},
@@ -31,11 +31,11 @@ use arrow::datatypes::{DataType, FieldRef};
3131
use crate::expr::WindowFunction;
3232
use crate::udf_eq::UdfEq;
3333
use crate::{
34-
function::WindowFunctionSimplification, udf_equals_hash, Expr, PartitionEvaluator,
35-
Signature,
34+
function::WindowFunctionSimplification, Expr, PartitionEvaluator, Signature,
3635
};
3736
use datafusion_common::{not_impl_err, Result};
3837
use datafusion_doc::Documentation;
38+
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
3939
use datafusion_functions_window_common::expr::ExpressionArgs;
4040
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
4141
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
@@ -82,15 +82,15 @@ impl Display for WindowUDF {
8282

8383
impl PartialEq for WindowUDF {
8484
fn eq(&self, other: &Self) -> bool {
85-
self.inner.equals(other.inner.as_ref())
85+
self.inner.dyn_eq(other.inner.as_any())
8686
}
8787
}
8888

8989
impl Eq for WindowUDF {}
9090

9191
impl Hash for WindowUDF {
9292
fn hash<H: Hasher>(&self, state: &mut H) {
93-
self.inner.hash_value().hash(state)
93+
self.inner.dyn_hash(state)
9494
}
9595
}
9696

@@ -229,6 +229,10 @@ where
229229
/// This trait exposes the full API for implementing user defined window functions and
230230
/// can be used to implement any function.
231231
///
232+
/// While the trait depends on [`DynEq`] and [`DynHash`] traits, these should not be
233+
/// implemented directly. Instead, implement [`Eq`] and [`Hash`] and leverage the
234+
/// blanket implementations of [`DynEq`] and [`DynHash`].
235+
///
232236
/// See [`advanced_udwf.rs`] for a full example with complete implementation and
233237
/// [`WindowUDF`] for other available options.
234238
///
@@ -246,7 +250,7 @@ where
246250
/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
247251
/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
248252
///
249-
/// #[derive(Debug, Clone)]
253+
/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
250254
/// struct SmoothIt {
251255
/// signature: Signature,
252256
/// }
@@ -305,7 +309,7 @@ where
305309
/// .build()
306310
/// .unwrap();
307311
/// ```
308-
pub trait WindowUDFImpl: Debug + Send + Sync {
312+
pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync {
309313
/// Returns this object as an [`Any`] trait object
310314
fn as_any(&self) -> &dyn Any;
311315

@@ -358,41 +362,6 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
358362
None
359363
}
360364

361-
/// Return true if this window UDF is equal to the other.
362-
///
363-
/// Allows customizing the equality of window UDFs.
364-
/// *Must* be implemented explicitly if the UDF type has internal state.
365-
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
366-
///
367-
/// - reflexive: `a.equals(a)`;
368-
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
369-
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
370-
///
371-
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
372-
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
373-
self.as_any().type_id() == other.as_any().type_id()
374-
&& self.name() == other.name()
375-
&& self.aliases() == other.aliases()
376-
&& self.signature() == other.signature()
377-
}
378-
379-
/// Returns a hash value for this window UDF.
380-
///
381-
/// Allows customizing the hash code of window UDFs.
382-
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
383-
///
384-
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
385-
/// their `hash_value`s must be the same.
386-
///
387-
/// By default, it only hashes the type. The other fields are not hashed, as usually the
388-
/// name, signature, and aliases are implied by the UDF type. Recall that UDFs with state
389-
/// (and thus possibly changing fields) must override [`Self::equals`] and [`Self::hash_value`].
390-
fn hash_value(&self) -> u64 {
391-
let hasher = &mut DefaultHasher::new();
392-
self.as_any().type_id().hash(hasher);
393-
hasher.finish()
394-
}
395-
396365
/// The [`FieldRef`] of the final result of evaluating this window function.
397366
///
398367
/// Call `field_args.name()` to get the fully qualified name for defining
@@ -461,7 +430,7 @@ pub enum ReversedUDWF {
461430

462431
impl PartialEq for dyn WindowUDFImpl {
463432
fn eq(&self, other: &Self) -> bool {
464-
self.equals(other)
433+
self.dyn_eq(other.as_any())
465434
}
466435
}
467436

@@ -533,8 +502,6 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
533502
self.inner.simplify()
534503
}
535504

536-
udf_equals_hash!(WindowUDFImpl);
537-
538505
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
539506
self.inner.field(field_args)
540507
}
@@ -598,7 +565,7 @@ mod test {
598565
use std::any::Any;
599566
use std::cmp::Ordering;
600567

601-
#[derive(Debug, Clone)]
568+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
602569
struct AWindowUDF {
603570
signature: Signature,
604571
}
@@ -637,7 +604,7 @@ mod test {
637604
}
638605
}
639606

640-
#[derive(Debug, Clone)]
607+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
641608
struct BWindowUDF {
642609
signature: Signature,
643610
}
@@ -676,6 +643,15 @@ mod test {
676643
}
677644
}
678645

646+
#[test]
647+
fn test_partial_eq() {
648+
let a1 = WindowUDF::from(AWindowUDF::new());
649+
let a2 = WindowUDF::from(AWindowUDF::new());
650+
let eq = a1 == a2;
651+
assert!(eq);
652+
assert_eq!(a1, a2);
653+
}
654+
679655
#[test]
680656
fn test_partial_ord() {
681657
let a1 = WindowUDF::from(AWindowUDF::new());

0 commit comments

Comments
 (0)