2020use std:: any:: Any ;
2121use std:: cmp:: Ordering ;
2222use std:: fmt:: { self , Debug , Formatter , Write } ;
23- use std:: hash:: { DefaultHasher , Hash , Hasher } ;
23+ use std:: hash:: { Hash , Hasher } ;
2424use std:: sync:: Arc ;
2525use std:: vec;
2626
2727use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2828
2929use datafusion_common:: { exec_err, not_impl_err, Result , ScalarValue , Statistics } ;
30+ use datafusion_expr_common:: dyn_eq:: { DynEq , DynHash } ;
3031use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
3132
3233use crate :: expr:: {
@@ -41,7 +42,7 @@ use crate::groups_accumulator::GroupsAccumulator;
4142use crate :: udf_eq:: UdfEq ;
4243use crate :: utils:: format_state_name;
4344use crate :: utils:: AggregateOrderSensitivity ;
44- use crate :: { expr_vec_fmt, udf_equals_hash , Accumulator , Expr } ;
45+ use crate :: { expr_vec_fmt, Accumulator , Expr } ;
4546use crate :: { Documentation , Signature } ;
4647
4748/// Logical representation of a user-defined [aggregate function] (UDAF).
@@ -82,15 +83,15 @@ pub struct AggregateUDF {
8283
8384impl PartialEq for AggregateUDF {
8485 fn eq ( & self , other : & Self ) -> bool {
85- self . inner . equals ( other. inner . as_ref ( ) )
86+ self . inner . dyn_eq ( other. inner . as_any ( ) )
8687 }
8788}
8889
8990impl Eq for AggregateUDF { }
9091
9192impl Hash for AggregateUDF {
9293 fn hash < H : Hasher > ( & self , state : & mut H ) {
93- self . inner . hash_value ( ) . hash ( state)
94+ self . inner . dyn_hash ( state)
9495 }
9596}
9697
@@ -373,7 +374,7 @@ where
373374/// # use arrow::datatypes::Schema;
374375/// # use arrow::datatypes::Field;
375376///
376- /// #[derive(Debug, Clone)]
377+ /// #[derive(Debug, Clone, PartialEq, Eq, Hash )]
377378/// struct GeoMeanUdf {
378379/// signature: Signature,
379380/// }
@@ -426,7 +427,7 @@ where
426427/// // Call the function `geo_mean(col)`
427428/// let expr = geometric_mean.call(vec![col("a")]);
428429/// ```
429- pub trait AggregateUDFImpl : Debug + Send + Sync {
430+ pub trait AggregateUDFImpl : Debug + DynEq + DynHash + Send + Sync {
430431 /// Returns this object as an [`Any`] trait object
431432 fn as_any ( & self ) -> & dyn Any ;
432433
@@ -914,41 +915,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
914915 not_impl_err ! ( "Function {} does not implement coerce_types" , self . name( ) )
915916 }
916917
917- /// Return true if this aggregate UDF is equal to the other.
918- ///
919- /// Allows customizing the equality of aggregate UDFs.
920- /// *Must* be implemented explicitly if the UDF type has internal state.
921- /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
922- ///
923- /// - reflexive: `a.equals(a)`;
924- /// - symmetric: `a.equals(b)` implies `b.equals(a)`;
925- /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
926- ///
927- /// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
928- fn equals ( & self , other : & dyn AggregateUDFImpl ) -> bool {
929- self . as_any ( ) . type_id ( ) == other. as_any ( ) . type_id ( )
930- && self . name ( ) == other. name ( )
931- && self . aliases ( ) == other. aliases ( )
932- && self . signature ( ) == other. signature ( )
933- }
934-
935- /// Returns a hash value for this aggregate UDF.
936- ///
937- /// Allows customizing the hash code of aggregate UDFs.
938- /// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
939- ///
940- /// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
941- /// their `hash_value`s must be the same.
942- ///
943- /// By default, it only hashes the type. The other fields are not hashed, as usually the
944- /// name, signature, and aliases are implied by the UDF type. Recall that UDFs with state
945- /// (and thus possibly changing fields) must override [`Self::equals`] and [`Self::hash_value`].
946- fn hash_value ( & self ) -> u64 {
947- let hasher = & mut DefaultHasher :: new ( ) ;
948- self . as_any ( ) . type_id ( ) . hash ( hasher) ;
949- hasher. finish ( )
950- }
951-
952918 /// If this function is max, return true
953919 /// If the function is min, return false
954920 /// Otherwise return None (the default)
@@ -1008,10 +974,11 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
1008974
1009975impl PartialEq for dyn AggregateUDFImpl {
1010976 fn eq ( & self , other : & Self ) -> bool {
1011- self . equals ( other)
977+ self . dyn_eq ( other. as_any ( ) )
1012978 }
1013979}
1014980
981+ // TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `dyn AggregateUDFImpl` and it should be
1015982// Manual implementation of `PartialOrd`
1016983// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
1017984// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
@@ -1194,8 +1161,6 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
11941161 self . inner . set_monotonicity ( data_type)
11951162 }
11961163
1197- udf_equals_hash ! ( AggregateUDFImpl ) ;
1198-
11991164 fn documentation ( & self ) -> Option < & Documentation > {
12001165 self . inner . documentation ( )
12011166 }
@@ -1266,7 +1231,7 @@ mod test {
12661231 use std:: any:: Any ;
12671232 use std:: cmp:: Ordering ;
12681233
1269- #[ derive( Debug , Clone ) ]
1234+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
12701235 struct AMeanUdf {
12711236 signature : Signature ,
12721237 }
@@ -1307,7 +1272,7 @@ mod test {
13071272 }
13081273 }
13091274
1310- #[ derive( Debug , Clone ) ]
1275+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
13111276 struct BMeanUdf {
13121277 signature : Signature ,
13131278 }
@@ -1347,6 +1312,15 @@ mod test {
13471312 }
13481313 }
13491314
1315+ #[ test]
1316+ fn test_partial_eq ( ) {
1317+ let a1 = AggregateUDF :: from ( AMeanUdf :: new ( ) ) ;
1318+ let a2 = AggregateUDF :: from ( AMeanUdf :: new ( ) ) ;
1319+ let eq = a1 == a2;
1320+ assert ! ( eq) ;
1321+ assert_eq ! ( a1, a2) ;
1322+ }
1323+
13501324 #[ test]
13511325 fn test_partial_ord ( ) {
13521326 // Test validates that partial ord is defined for AggregateUDF using the name and signature,
0 commit comments