Skip to content

Commit b8bf7c5

Browse files
authored
Derive AggregateUDFImpl equality, hash from Eq, Hash traits (#17130)
Follows a similar change for `WindowUDFImpl` (commit 8494a39). Previously, the `AggregateUDFImpl` trait contained `equals` and `hash_value` methods with contracts following the `Eq` and `Hash` traits. However, the existence of default implementations of these methods made the API error-prone: many functions (scalar, aggregate, window) failed to customize equality even though they ought to have. There is no way to fix this without an API-breaking change, so a breaking change is warranted. Removing the default implementations would be a sufficient fix, but at the cost of a lot of boilerplate in implementations. Instead, this removes the methods from the trait and reuses the `DynEq` and `DynHash` traits, which were previously used only for physical expressions. This allows functions to provide their implementations using no more than `#[derive(PartialEq, Eq, Hash)]` in the typical case.
1 parent 25ad99d commit b8bf7c5

File tree

34 files changed

+79
-110
lines changed

34 files changed

+79
-110
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use datafusion::prelude::*;
4141
/// a function `accumulator` that returns the `Accumulator` instance.
4242
///
4343
/// To do so, we must implement the `AggregateUDFImpl` trait.
44-
#[derive(Debug, Clone)]
44+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
4545
struct GeoMeanUdaf {
4646
signature: Signature,
4747
}
@@ -368,7 +368,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
368368

369369
/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
370370
/// defined aggregate function with a different expression which is defined in the `simplify` method.
371-
#[derive(Debug, Clone)]
371+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
372372
struct SimplifiedGeoMeanUdaf {
373373
signature: Signature,
374374
}

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ use datafusion_common::{assert_contains, exec_datafusion_err};
5555
use datafusion_common::{cast::as_primitive_array, exec_err};
5656
use datafusion_expr::expr::WindowFunction;
5757
use datafusion_expr::{
58-
col, create_udaf, function::AccumulatorArgs, udf_equals_hash, AggregateUDFImpl, Expr,
58+
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
5959
GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
6060
};
6161
use datafusion_functions_aggregate::average::AvgAccumulator;
@@ -816,8 +816,6 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
816816
) -> Result<Box<dyn GroupsAccumulator>> {
817817
Ok(Box::new(self.clone()))
818818
}
819-
820-
udf_equals_hash!(AggregateUDFImpl);
821819
}
822820

823821
impl Accumulator for TestGroupsAccumulator {
@@ -970,8 +968,6 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
970968
curr_sum: 0,
971969
}))
972970
}
973-
974-
udf_equals_hash!(AggregateUDFImpl);
975971
}
976972

977973
#[derive(Debug)]

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,8 +600,6 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
600600
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
601601
Ok(self.state_fields.clone())
602602
}
603-
604-
udf_equals_hash!(AggregateUDFImpl);
605603
}
606604

607605
/// Creates a new UDWF with a specific signature, state type and return type.

datafusion/expr/src/test/function_stub.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ pub fn avg(expr: Expr) -> Expr {
9292
}
9393

9494
/// Stub `sum` used for optimizer testing
95-
#[derive(Debug)]
95+
#[derive(Debug, PartialEq, Eq, Hash)]
9696
pub struct Sum {
9797
signature: Signature,
9898
}
@@ -200,6 +200,7 @@ impl AggregateUDFImpl for Sum {
200200
}
201201

202202
/// Testing stub implementation of COUNT aggregate
203+
#[derive(PartialEq, Eq, Hash)]
203204
pub struct Count {
204205
signature: Signature,
205206
aliases: Vec<String>,
@@ -288,6 +289,7 @@ pub fn min(expr: Expr) -> Expr {
288289
}
289290

290291
/// Testing stub implementation of Min aggregate
292+
#[derive(PartialEq, Eq, Hash)]
291293
pub struct Min {
292294
signature: Signature,
293295
}
@@ -369,6 +371,7 @@ pub fn max(expr: Expr) -> Expr {
369371
}
370372

371373
/// Testing stub implementation of MAX aggregate
374+
#[derive(PartialEq, Eq, Hash)]
372375
pub struct Max {
373376
signature: Signature,
374377
}
@@ -437,7 +440,7 @@ impl AggregateUDFImpl for Max {
437440
}
438441

439442
/// Testing stub implementation of avg aggregate
440-
#[derive(Debug)]
443+
#[derive(Debug, PartialEq, Eq, Hash)]
441444
pub struct Avg {
442445
signature: Signature,
443446
aliases: Vec<String>,

datafusion/expr/src/udaf.rs

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
use std::any::Any;
2121
use std::cmp::Ordering;
2222
use std::fmt::{self, Debug, Formatter, Write};
23-
use std::hash::{DefaultHasher, Hash, Hasher};
23+
use std::hash::{Hash, Hasher};
2424
use std::sync::Arc;
2525
use std::vec;
2626

2727
use arrow::datatypes::{DataType, Field, FieldRef};
2828

2929
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
30+
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
3031
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
3132

3233
use crate::expr::{
@@ -41,7 +42,7 @@ use crate::groups_accumulator::GroupsAccumulator;
4142
use crate::udf_eq::UdfEq;
4243
use crate::utils::format_state_name;
4344
use crate::utils::AggregateOrderSensitivity;
44-
use crate::{expr_vec_fmt, udf_equals_hash, Accumulator, Expr};
45+
use crate::{expr_vec_fmt, Accumulator, Expr};
4546
use crate::{Documentation, Signature};
4647

4748
/// Logical representation of a user-defined [aggregate function] (UDAF).
@@ -82,15 +83,15 @@ pub struct AggregateUDF {
8283

8384
impl PartialEq for AggregateUDF {
8485
fn eq(&self, other: &Self) -> bool {
85-
self.inner.equals(other.inner.as_ref())
86+
self.inner.dyn_eq(other.inner.as_any())
8687
}
8788
}
8889

8990
impl Eq for AggregateUDF {}
9091

9192
impl Hash for AggregateUDF {
9293
fn hash<H: Hasher>(&self, state: &mut H) {
93-
self.inner.hash_value().hash(state)
94+
self.inner.dyn_hash(state)
9495
}
9596
}
9697

@@ -373,7 +374,7 @@ where
373374
/// # use arrow::datatypes::Schema;
374375
/// # use arrow::datatypes::Field;
375376
///
376-
/// #[derive(Debug, Clone)]
377+
/// #[derive(Debug, Clone, PartialEq, Eq, Hash)]
377378
/// struct GeoMeanUdf {
378379
/// signature: Signature,
379380
/// }
@@ -426,7 +427,7 @@ where
426427
/// // Call the function `geo_mean(col)`
427428
/// let expr = geometric_mean.call(vec![col("a")]);
428429
/// ```
429-
pub trait AggregateUDFImpl: Debug + Send + Sync {
430+
pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
430431
/// Returns this object as an [`Any`] trait object
431432
fn as_any(&self) -> &dyn Any;
432433

@@ -914,41 +915,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
914915
not_impl_err!("Function {} does not implement coerce_types", self.name())
915916
}
916917

917-
/// Return true if this aggregate UDF is equal to the other.
918-
///
919-
/// Allows customizing the equality of aggregate UDFs.
920-
/// *Must* be implemented explicitly if the UDF type has internal state.
921-
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
922-
///
923-
/// - reflexive: `a.equals(a)`;
924-
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
925-
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
926-
///
927-
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
928-
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
929-
self.as_any().type_id() == other.as_any().type_id()
930-
&& self.name() == other.name()
931-
&& self.aliases() == other.aliases()
932-
&& self.signature() == other.signature()
933-
}
934-
935-
/// Returns a hash value for this aggregate UDF.
936-
///
937-
/// Allows customizing the hash code of aggregate UDFs.
938-
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
939-
///
940-
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
941-
/// their `hash_value`s must be the same.
942-
///
943-
/// By default, it only hashes the type. The other fields are not hashed, as usually the
944-
/// name, signature, and aliases are implied by the UDF type. Recall that UDFs with state
945-
/// (and thus possibly changing fields) must override [`Self::equals`] and [`Self::hash_value`].
946-
fn hash_value(&self) -> u64 {
947-
let hasher = &mut DefaultHasher::new();
948-
self.as_any().type_id().hash(hasher);
949-
hasher.finish()
950-
}
951-
952918
/// If this function is max, return true
953919
/// If the function is min, return false
954920
/// Otherwise return None (the default)
@@ -1008,10 +974,11 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
1008974

1009975
impl PartialEq for dyn AggregateUDFImpl {
1010976
fn eq(&self, other: &Self) -> bool {
1011-
self.equals(other)
977+
self.dyn_eq(other.as_any())
1012978
}
1013979
}
1014980

981+
// TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `dyn AggregateUDFImpl` and it should be
1015982
// Manual implementation of `PartialOrd`
1016983
// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
1017984
// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
@@ -1194,8 +1161,6 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
11941161
self.inner.set_monotonicity(data_type)
11951162
}
11961163

1197-
udf_equals_hash!(AggregateUDFImpl);
1198-
11991164
fn documentation(&self) -> Option<&Documentation> {
12001165
self.inner.documentation()
12011166
}
@@ -1266,7 +1231,7 @@ mod test {
12661231
use std::any::Any;
12671232
use std::cmp::Ordering;
12681233

1269-
#[derive(Debug, Clone)]
1234+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
12701235
struct AMeanUdf {
12711236
signature: Signature,
12721237
}
@@ -1307,7 +1272,7 @@ mod test {
13071272
}
13081273
}
13091274

1310-
#[derive(Debug, Clone)]
1275+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
13111276
struct BMeanUdf {
13121277
signature: Signature,
13131278
}
@@ -1347,6 +1312,15 @@ mod test {
13471312
}
13481313
}
13491314

1315+
#[test]
1316+
fn test_partial_eq() {
1317+
let a1 = AggregateUDF::from(AMeanUdf::new());
1318+
let a2 = AggregateUDF::from(AMeanUdf::new());
1319+
let eq = a1 == a2;
1320+
assert!(eq);
1321+
assert_eq!(a1, a2);
1322+
}
1323+
13501324
#[test]
13511325
fn test_partial_ord() {
13521326
// Test validates that partial ord is defined for AggregateUDF using the name and signature,

datafusion/expr/src/udf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ impl PartialEq for ScalarUDF {
6666
}
6767
}
6868

69+
// TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `ScalarUDF` and it should be
6970
// Manual implementation based on `ScalarUDFImpl::equals`
7071
impl PartialOrd for ScalarUDF {
7172
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {

datafusion/expr/src/udf_eq.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,20 @@ macro_rules! impl_for_udf_eq {
9595
};
9696
}
9797

98-
impl_for_udf_eq!(dyn AggregateUDFImpl + '_);
9998
impl_for_udf_eq!(dyn ScalarUDFImpl + '_);
10099

100+
impl UdfPointer for Arc<dyn AggregateUDFImpl + '_> {
101+
fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool {
102+
self.as_ref().dyn_eq(other.as_any())
103+
}
104+
105+
fn hash_value(&self) -> u64 {
106+
let hasher = &mut DefaultHasher::new();
107+
self.as_ref().dyn_hash(hasher);
108+
hasher.finish()
109+
}
110+
}
111+
101112
impl UdfPointer for Arc<dyn WindowUDFImpl + '_> {
102113
fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool {
103114
self.as_ref().dyn_eq(other.as_any())

datafusion/expr/src/udwf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ impl PartialEq for dyn WindowUDFImpl {
434434
}
435435
}
436436

437+
// TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `dyn WindowUDFImpl` and it should be
437438
impl PartialOrd for dyn WindowUDFImpl {
438439
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
439440
match self.name().partial_cmp(other.name()) {

datafusion/ffi/src/udaf/mod.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ use crate::{
4949
util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped},
5050
volatility::FFI_Volatility,
5151
};
52-
use datafusion::logical_expr::udf_equals_hash;
5352
use prost::{DecodeError, Message};
5453

5554
mod accumulator;
@@ -567,8 +566,6 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
567566
Ok(rvec_wrapped_to_vec_datatype(&result_types)?)
568567
}
569568
}
570-
571-
udf_equals_hash!(AggregateUDFImpl);
572569
}
573570

574571
#[repr(C)]

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ impl Default for ApproxDistinct {
293293
```"#,
294294
standard_argument(name = "expression",)
295295
)]
296+
#[derive(PartialEq, Eq, Hash)]
296297
pub struct ApproxDistinct {
297298
signature: Signature,
298299
}

0 commit comments

Comments
 (0)