Skip to content

Commit 1961fe7

Browse files
authored
Derive ScalarUDFImpl equality, hash from Eq, Hash traits (#17164)
* Remove redundant JsonGetStr::aliases field * Derive `ScalarUDFImpl` equality, hash from `Eq`, `Hash` traits Follows similar changes for `WindowUDFImpl` and `AggregateUDFImpl`, i.e. the 8494a39 and b8bf7c5 commits. Previously, the `ScalarUDFImpl` trait contained `equals` and `hash_value` methods with contracts following the `Eq` and `Hash` traits. However, the existence of default implementations of these methods made the trait error-prone, with many functions (scalar, aggregate, window) failing to customize `equals` even though they ought to. There is no fix for this that is not an API-breaking change, so a breaking change is warranted. Removing the default implementations would be enough of a solution, but at the cost of a lot of boilerplate in implementations. Instead, this removes the methods from the trait and reuses the `DynEq` and `DynHash` traits, previously used only for physical expressions. This allows functions to provide their implementations using no more than `#[derive(PartialEq, Eq, Hash)]` in the typical case.
1 parent c2714db commit 1961fe7

File tree

164 files changed

+300
-497
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

164 files changed

+300
-497
lines changed

datafusion-examples/examples/advanced_udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ use datafusion::prelude::*;
3939
/// the power of the second argument `a^b`.
4040
///
4141
/// To do so, we must implement the `ScalarUDFImpl` trait.
42-
#[derive(Debug, Clone)]
42+
#[derive(Debug, PartialEq, Eq, Hash)]
4343
struct PowUdf {
4444
signature: Signature,
4545
aliases: Vec<String>,

datafusion-examples/examples/async_udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ fn animal() -> Result<RecordBatch> {
133133
///
134134
/// Since this is a simplified example, it does not call an LLM service, but
135135
/// could be extended to do so in a real-world scenario.
136-
#[derive(Debug)]
136+
#[derive(Debug, PartialEq, Eq, Hash)]
137137
struct AskLLM {
138138
signature: Signature,
139139
}

datafusion-examples/examples/function_factory.rs

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use datafusion::logical_expr::{
2828
ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
2929
Signature, Volatility,
3030
};
31-
use std::hash::{DefaultHasher, Hash, Hasher};
31+
use std::hash::Hash;
3232
use std::result::Result as RResult;
3333
use std::sync::Arc;
3434

@@ -107,7 +107,7 @@ impl FunctionFactory for CustomFunctionFactory {
107107
}
108108

109109
/// this function represents the newly created execution engine.
110-
#[derive(Debug)]
110+
#[derive(Debug, PartialEq, Eq, Hash)]
111111
struct ScalarFunctionWrapper {
112112
/// The text of the function body, `$1 + f1($2)` in our example
113113
name: String,
@@ -154,38 +154,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
154154
fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
155155
Ok(SortProperties::Unordered)
156156
}
157-
158-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
159-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
160-
return false;
161-
};
162-
let Self {
163-
name,
164-
expr,
165-
signature,
166-
return_type,
167-
} = self;
168-
name == &other.name
169-
&& expr == &other.expr
170-
&& signature == &other.signature
171-
&& return_type == &other.return_type
172-
}
173-
174-
fn hash_value(&self) -> u64 {
175-
let Self {
176-
name,
177-
expr,
178-
signature,
179-
return_type,
180-
} = self;
181-
let mut hasher = DefaultHasher::new();
182-
std::any::type_name::<Self>().hash(&mut hasher);
183-
name.hash(&mut hasher);
184-
expr.hash(&mut hasher);
185-
signature.hash(&mut hasher);
186-
return_type.hash(&mut hasher);
187-
hasher.finish()
188-
}
189157
}
190158

191159
impl ScalarFunctionWrapper {

datafusion-examples/examples/json_shredding.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,17 +282,15 @@ impl TableProvider for ExampleTableProvider {
282282
}
283283

284284
/// Scalar UDF that uses serde_json to access json fields
285-
#[derive(Debug)]
285+
#[derive(Debug, PartialEq, Eq, Hash)]
286286
pub struct JsonGetStr {
287287
signature: Signature,
288-
aliases: [String; 1],
289288
}
290289

291290
impl Default for JsonGetStr {
292291
fn default() -> Self {
293292
Self {
294293
signature: Signature::variadic_any(Volatility::Immutable),
295-
aliases: ["json_get_str".to_string()],
296294
}
297295
}
298296
}
@@ -303,7 +301,7 @@ impl ScalarUDFImpl for JsonGetStr {
303301
}
304302

305303
fn name(&self) -> &str {
306-
self.aliases[0].as_str()
304+
"json_get_str"
307305
}
308306

309307
fn signature(&self) -> &Signature {
@@ -355,10 +353,6 @@ impl ScalarUDFImpl for JsonGetStr {
355353
.collect::<StringArray>();
356354
Ok(ColumnarValue::Array(Arc::new(values)))
357355
}
358-
359-
fn aliases(&self) -> &[String] {
360-
&self.aliases
361-
}
362356
}
363357

364358
/// Factory for creating ShreddedJsonRewriter instances

datafusion-examples/examples/optimizer_rule.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ fn is_lit_or_col(expr: &Expr) -> bool {
175175
}
176176

177177
/// A simple user defined filter function
178-
#[derive(Debug, Clone)]
178+
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
179179
struct MyEq {
180180
signature: Signature,
181181
}

datafusion/core/tests/fuzz_cases/equivalence/utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ fn get_sort_columns(
512512
.collect::<Result<Vec<_>>>()
513513
}
514514

515-
#[derive(Debug, Clone)]
515+
#[derive(Debug, PartialEq, Eq, Hash)]
516516
pub struct TestScalarUDF {
517517
pub(crate) signature: Signature,
518518
}

datafusion/core/tests/physical_optimizer/projection_pushdown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use insta::assert_snapshot;
6363
use itertools::Itertools;
6464

6565
/// Mocked UDF
66-
#[derive(Debug)]
66+
#[derive(Debug, PartialEq, Eq, Hash)]
6767
struct DummyUDF {
6868
signature: Signature,
6969
}

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ use datafusion_common::{
4343
use datafusion_expr::expr::FieldMetadata;
4444
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
4545
use datafusion_expr::{
46-
lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, CreateFunction,
47-
CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs,
48-
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
46+
lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody,
47+
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
48+
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
4949
};
5050
use datafusion_functions_nested::range::range_udf;
5151
use parking_lot::Mutex;
@@ -218,8 +218,6 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
218218
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
219219
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
220220
}
221-
222-
udf_equals_hash!(ScalarUDFImpl);
223221
}
224222

225223
#[tokio::test]
@@ -560,8 +558,6 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
560558
};
561559
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
562560
}
563-
564-
udf_equals_hash!(ScalarUDFImpl);
565561
}
566562

567563
#[tokio::test]
@@ -665,7 +661,7 @@ async fn volatile_scalar_udf_with_params() -> Result<()> {
665661
Ok(())
666662
}
667663

668-
#[derive(Debug)]
664+
#[derive(Debug, PartialEq, Eq, Hash)]
669665
struct CastToI64UDF {
670666
signature: Signature,
671667
}
@@ -787,7 +783,7 @@ async fn deregister_udf() -> Result<()> {
787783
Ok(())
788784
}
789785

790-
#[derive(Debug)]
786+
#[derive(Debug, PartialEq, Eq, Hash)]
791787
struct TakeUDF {
792788
signature: Signature,
793789
}
@@ -979,8 +975,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
979975

980976
Ok(ExprSimplifyResult::Simplified(replacement))
981977
}
982-
983-
udf_equals_hash!(ScalarUDFImpl);
984978
}
985979

986980
impl ScalarFunctionWrapper {
@@ -1282,8 +1276,6 @@ impl ScalarUDFImpl for MyRegexUdf {
12821276
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
12831277
}
12841278
}
1285-
1286-
udf_equals_hash!(ScalarUDFImpl);
12871279
}
12881280

12891281
#[tokio::test]
@@ -1471,8 +1463,6 @@ impl ScalarUDFImpl for MetadataBasedUdf {
14711463
}
14721464
}
14731465
}
1474-
1475-
udf_equals_hash!(ScalarUDFImpl);
14761466
}
14771467

14781468
#[tokio::test]
@@ -1611,7 +1601,7 @@ async fn test_metadata_based_udf_with_literal() -> Result<()> {
16111601
/// sides. For the input, we will handle the data differently if there is
16121602
/// the canonical extension type Bool8. For the output we will add a
16131603
/// user defined extension type.
1614-
#[derive(Debug)]
1604+
#[derive(Debug, PartialEq, Eq, Hash)]
16151605
struct ExtensionBasedUdf {
16161606
name: String,
16171607
signature: Signature,
@@ -1790,7 +1780,7 @@ async fn test_extension_based_udf() -> Result<()> {
17901780

17911781
#[tokio::test]
17921782
async fn test_config_options_work_for_scalar_func() -> Result<()> {
1793-
#[derive(Debug)]
1783+
#[derive(Debug, PartialEq, Eq, Hash)]
17941784
struct TestScalarUDF {
17951785
signature: Signature,
17961786
}

datafusion/expr/src/async_udf.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
// under the License.
1717

1818
use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash};
19-
use crate::{
20-
udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
21-
};
19+
use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
2220
use arrow::datatypes::{DataType, FieldRef};
2321
use async_trait::async_trait;
2422
use datafusion_common::error::Result;
@@ -127,8 +125,6 @@ impl ScalarUDFImpl for AsyncScalarUDF {
127125
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
128126
internal_err!("async functions should not be called directly")
129127
}
130-
131-
udf_equals_hash!(ScalarUDFImpl);
132128
}
133129

134130
impl Display for AsyncScalarUDF {

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3648,7 +3648,7 @@ mod test {
36483648
#[test]
36493649
fn test_is_volatile_scalar_func() {
36503650
// UDF
3651-
#[derive(Debug)]
3651+
#[derive(Debug, PartialEq, Eq, Hash)]
36523652
struct TestScalarUDF {
36533653
signature: Signature,
36543654
}

0 commit comments

Comments
 (0)