Skip to content

Commit ac3a573

Browse files
authored
Derive UDAF equality from Eq, Hash (#17067)
* Require Eq to use udf_equals_hash

  The UDF comparison is expected to be reflexive. Require `Eq` for any use of the `udf_equals_hash` shortcut.

* Add UdfEq wrapper around Arc to UDF impl

  The wrapper implements `PartialEq`, `Eq`, and `Hash` by forwarding to the UDF impl's `equals` and `hash_value` functions.

* Derive UDAF equality from Eq, Hash

  Reduce boilerplate in cases where the implementation of `AggregateUDFImpl::{equals,hash_value}` can be derived using the standard `Eq` and `Hash` traits.
1 parent f9efba0 commit ac3a573

File tree

29 files changed

+323
-387
lines changed

29 files changed

+323
-387
lines changed

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
2121
use std::any::Any;
2222
use std::collections::HashMap;
23-
use std::hash::{DefaultHasher, Hash, Hasher};
23+
use std::hash::{Hash, Hasher};
2424
use std::mem::{size_of, size_of_val};
2525
use std::sync::{
2626
atomic::{AtomicBool, Ordering},
@@ -55,7 +55,7 @@ use datafusion_common::{assert_contains, exec_datafusion_err};
5555
use datafusion_common::{cast::as_primitive_array, exec_err};
5656
use datafusion_expr::expr::WindowFunction;
5757
use datafusion_expr::{
58-
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
58+
col, create_udaf, function::AccumulatorArgs, udf_equals_hash, AggregateUDFImpl, Expr,
5959
GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
6060
};
6161
use datafusion_functions_aggregate::average::AvgAccumulator;
@@ -778,7 +778,7 @@ impl Accumulator for FirstSelector {
778778
}
779779
}
780780

781-
#[derive(Debug, Clone)]
781+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
782782
struct TestGroupsAccumulator {
783783
signature: Signature,
784784
result: u64,
@@ -817,20 +817,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
817817
Ok(Box::new(self.clone()))
818818
}
819819

820-
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
821-
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
822-
self.result == other.result && self.signature == other.signature
823-
} else {
824-
false
825-
}
826-
}
827-
828-
fn hash_value(&self) -> u64 {
829-
let hasher = &mut DefaultHasher::new();
830-
self.signature.hash(hasher);
831-
self.result.hash(hasher);
832-
hasher.finish()
833-
}
820+
udf_equals_hash!(AggregateUDFImpl);
834821
}
835822

836823
impl Accumulator for TestGroupsAccumulator {
@@ -902,6 +889,32 @@ struct MetadataBasedAggregateUdf {
902889
metadata: HashMap<String, String>,
903890
}
904891

892+
impl PartialEq for MetadataBasedAggregateUdf {
893+
fn eq(&self, other: &Self) -> bool {
894+
let Self {
895+
name,
896+
signature,
897+
metadata,
898+
} = self;
899+
name == &other.name
900+
&& signature == &other.signature
901+
&& metadata == &other.metadata
902+
}
903+
}
904+
impl Eq for MetadataBasedAggregateUdf {}
905+
impl Hash for MetadataBasedAggregateUdf {
906+
fn hash<H: Hasher>(&self, state: &mut H) {
907+
let Self {
908+
name,
909+
signature,
910+
metadata: _, // unhashable
911+
} = self;
912+
std::any::type_name::<Self>().hash(state);
913+
name.hash(state);
914+
signature.hash(state);
915+
}
916+
}
917+
905918
impl MetadataBasedAggregateUdf {
906919
fn new(metadata: HashMap<String, String>) -> Self {
907920
// The name we return must be unique. Otherwise we will not call distinct
@@ -958,32 +971,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
958971
}))
959972
}
960973

961-
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
962-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
963-
return false;
964-
};
965-
let Self {
966-
name,
967-
signature,
968-
metadata,
969-
} = self;
970-
name == &other.name
971-
&& signature == &other.signature
972-
&& metadata == &other.metadata
973-
}
974-
975-
fn hash_value(&self) -> u64 {
976-
let Self {
977-
name,
978-
signature,
979-
metadata: _, // unhashable
980-
} = self;
981-
let mut hasher = DefaultHasher::new();
982-
std::any::type_name::<Self>().hash(&mut hasher);
983-
name.hash(&mut hasher);
984-
signature.hash(&mut hasher);
985-
hasher.finish()
986-
}
974+
udf_equals_hash!(AggregateUDFImpl);
987975
}
988976

989977
#[derive(Debug)]

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async fn scalar_udf() -> Result<()> {
181181
Ok(())
182182
}
183183

184-
#[derive(PartialEq, Hash)]
184+
#[derive(PartialEq, Eq, Hash)]
185185
struct Simple0ArgsScalarUDF {
186186
name: String,
187187
signature: Signature,
@@ -492,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
492492
}
493493

494494
/// Volatile UDF that should append a different value to each row
495-
#[derive(Debug, PartialEq, Hash)]
495+
#[derive(Debug, PartialEq, Eq, Hash)]
496496
struct AddIndexToStringVolatileScalarUDF {
497497
name: String,
498498
signature: Signature,
@@ -941,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory {
941941
//
942942
// it also defines custom [ScalarUDFImpl::simplify()]
943943
// to replace ScalarUDF expression with one instance contains.
944-
#[derive(Debug, PartialEq, Hash)]
944+
#[derive(Debug, PartialEq, Eq, Hash)]
945945
struct ScalarFunctionWrapper {
946946
name: String,
947947
expr: Expr,
@@ -1221,6 +1221,7 @@ impl PartialEq for MyRegexUdf {
12211221
signature == &other.signature && regex.as_str() == other.regex.as_str()
12221222
}
12231223
}
1224+
impl Eq for MyRegexUdf {}
12241225

12251226
impl Hash for MyRegexUdf {
12261227
fn hash<H: Hasher>(&self, state: &mut H) {
@@ -1380,7 +1381,7 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB
13801381
ctx.sql(sql).await?.collect().await
13811382
}
13821383

1383-
#[derive(Debug, PartialEq)]
1384+
#[derive(Debug, PartialEq, Eq)]
13841385
struct MetadataBasedUdf {
13851386
name: String,
13861387
signature: Signature,

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ impl OddCounter {
525525
}
526526

527527
fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
528-
#[derive(Debug, Clone, PartialEq, Hash)]
528+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
529529
struct SimpleWindowUDF {
530530
signature: Signature,
531531
test_state: PtrEq<Arc<TestState>>,

datafusion/doc/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
/// thus all text should be in English.
4040
///
4141
/// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html
42-
#[derive(Debug, Clone, PartialEq, Hash)]
42+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
4343
pub struct Documentation {
4444
/// The section in the documentation where the UDF will be documented
4545
pub doc_section: DocSection,
@@ -158,7 +158,7 @@ impl Documentation {
158158
}
159159
}
160160

161-
#[derive(Debug, Clone, PartialEq, Hash)]
161+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
162162
pub struct DocSection {
163163
/// True to include this doc section in the public
164164
/// documentation, false otherwise

datafusion/expr/src/async_udf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ impl PartialEq for AsyncScalarUDF {
6969
arc_ptr_eq(inner, &other.inner)
7070
}
7171
}
72+
impl Eq for AsyncScalarUDF {}
7273

7374
impl Hash for AsyncScalarUDF {
7475
fn hash<H: Hasher>(&self, state: &mut H) {

datafusion/expr/src/expr_fn.rs

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
4545
use sqlparser::ast::NullTreatment;
4646
use std::any::Any;
4747
use std::fmt::Debug;
48-
use std::hash::{DefaultHasher, Hash, Hasher};
48+
use std::hash::Hash;
4949
use std::ops::Not;
5050
use std::sync::Arc;
5151

@@ -403,7 +403,7 @@ pub fn create_udf(
403403

404404
/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
405405
/// return type.
406-
#[derive(PartialEq, Hash)]
406+
#[derive(PartialEq, Eq, Hash)]
407407
pub struct SimpleScalarUDF {
408408
name: String,
409409
signature: Signature,
@@ -511,11 +511,12 @@ pub fn create_udaf(
511511

512512
/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
513513
/// return type.
514+
#[derive(PartialEq, Eq, Hash)]
514515
pub struct SimpleAggregateUDF {
515516
name: String,
516517
signature: Signature,
517518
return_type: DataType,
518-
accumulator: AccumulatorFactoryFunction,
519+
accumulator: PtrEq<AccumulatorFactoryFunction>,
519520
state_fields: Vec<FieldRef>,
520521
}
521522

@@ -547,7 +548,7 @@ impl SimpleAggregateUDF {
547548
name,
548549
signature,
549550
return_type,
550-
accumulator,
551+
accumulator: accumulator.into(),
551552
state_fields,
552553
}
553554
}
@@ -566,7 +567,7 @@ impl SimpleAggregateUDF {
566567
name,
567568
signature,
568569
return_type,
569-
accumulator,
570+
accumulator: accumulator.into(),
570571
state_fields,
571572
}
572573
}
@@ -600,41 +601,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
600601
Ok(self.state_fields.clone())
601602
}
602603

603-
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
604-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
605-
return false;
606-
};
607-
let Self {
608-
name,
609-
signature,
610-
return_type,
611-
accumulator,
612-
state_fields,
613-
} = self;
614-
name == &other.name
615-
&& signature == &other.signature
616-
&& return_type == &other.return_type
617-
&& Arc::ptr_eq(accumulator, &other.accumulator)
618-
&& state_fields == &other.state_fields
619-
}
620-
621-
fn hash_value(&self) -> u64 {
622-
let Self {
623-
name,
624-
signature,
625-
return_type,
626-
accumulator,
627-
state_fields,
628-
} = self;
629-
let mut hasher = DefaultHasher::new();
630-
std::any::type_name::<Self>().hash(&mut hasher);
631-
name.hash(&mut hasher);
632-
signature.hash(&mut hasher);
633-
return_type.hash(&mut hasher);
634-
Arc::as_ptr(accumulator).hash(&mut hasher);
635-
state_fields.hash(&mut hasher);
636-
hasher.finish()
637-
}
604+
udf_equals_hash!(AggregateUDFImpl);
638605
}
639606

640607
/// Creates a new UDWF with a specific signature, state type and return type.
@@ -661,7 +628,7 @@ pub fn create_udwf(
661628

662629
/// Implements [`WindowUDFImpl`] for functions that have a single signature and
663630
/// return type.
664-
#[derive(PartialEq, Hash)]
631+
#[derive(PartialEq, Eq, Hash)]
665632
pub struct SimpleWindowUDF {
666633
name: String,
667634
signature: Signature,

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ pub mod ptr_eq;
7171
pub mod test;
7272
pub mod tree_node;
7373
pub mod type_coercion;
74+
pub mod udf_eq;
7475
pub mod utils;
7576
pub mod var_provider;
7677
pub mod window_frame;

datafusion/expr/src/ptr_eq.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) {
3434
std::ptr::hash(Arc::as_ptr(a), hasher)
3535
}
3636

37-
/// A wrapper around a pointer that implements `PartialEq` and `Hash` comparing
37+
/// A wrapper around a pointer that implements `Eq` and `Hash` comparing
3838
/// the underlying pointer address.
3939
#[derive(Clone)]
4040
#[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse.
@@ -48,6 +48,7 @@ where
4848
arc_ptr_eq(&self.0, &other.0)
4949
}
5050
}
51+
impl<T> Eq for PtrEq<Arc<T>> where T: ?Sized {}
5152

5253
impl<T> Hash for PtrEq<Arc<T>>
5354
where

datafusion/expr/src/udaf.rs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ use crate::function::{
3838
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
3939
};
4040
use crate::groups_accumulator::GroupsAccumulator;
41+
use crate::udf_eq::UdfEq;
4142
use crate::utils::format_state_name;
4243
use crate::utils::AggregateOrderSensitivity;
43-
use crate::{expr_vec_fmt, Accumulator, Expr};
44+
use crate::{expr_vec_fmt, udf_equals_hash, Accumulator, Expr};
4445
use crate::{Documentation, Signature};
4546

4647
/// Logical representation of a user-defined [aggregate function] (UDAF).
@@ -1037,9 +1038,9 @@ pub enum ReversedUDAF {
10371038

10381039
/// AggregateUDF that adds an alias to the underlying function. It is better to
10391040
/// implement [`AggregateUDFImpl`], which supports aliases, directly if possible.
1040-
#[derive(Debug)]
1041+
#[derive(Debug, PartialEq, Eq, Hash)]
10411042
struct AliasedAggregateUDFImpl {
1042-
inner: Arc<dyn AggregateUDFImpl>,
1043+
inner: UdfEq<Arc<dyn AggregateUDFImpl>>,
10431044
aliases: Vec<String>,
10441045
}
10451046

@@ -1051,7 +1052,10 @@ impl AliasedAggregateUDFImpl {
10511052
let mut aliases = inner.aliases().to_vec();
10521053
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
10531054

1054-
Self { inner, aliases }
1055+
Self {
1056+
inner: inner.into(),
1057+
aliases,
1058+
}
10551059
}
10561060
}
10571061

@@ -1111,7 +1115,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
11111115
.map(|udf| {
11121116
udf.map(|udf| {
11131117
Arc::new(AliasedAggregateUDFImpl {
1114-
inner: udf,
1118+
inner: udf.into(),
11151119
aliases: self.aliases.clone(),
11161120
}) as Arc<dyn AggregateUDFImpl>
11171121
})
@@ -1134,20 +1138,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
11341138
self.inner.coerce_types(arg_types)
11351139
}
11361140

1137-
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
1138-
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
1139-
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
1140-
} else {
1141-
false
1142-
}
1143-
}
1144-
1145-
fn hash_value(&self) -> u64 {
1146-
let hasher = &mut DefaultHasher::new();
1147-
self.inner.hash_value().hash(hasher);
1148-
self.aliases.hash(hasher);
1149-
hasher.finish()
1150-
}
1141+
udf_equals_hash!(AggregateUDFImpl);
11511142

11521143
fn is_descending(&self) -> Option<bool> {
11531144
self.inner.is_descending()

0 commit comments

Comments
 (0)