Skip to content

Commit dad564f

Browse files
authored
chore: move udf registration to better place (#1899)
1 parent ca6f113 commit dad564f

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

native/core/src/execution/jni_api.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ use crate::{
5151
use datafusion::common::ScalarValue;
5252
use datafusion::execution::disk_manager::DiskManagerMode;
5353
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
54+
use datafusion::logical_expr::ScalarUDF;
5455
use datafusion_comet_proto::spark_operator::Operator;
56+
use datafusion_spark::function::math::expm1::SparkExpm1;
5557
use futures::stream::StreamExt;
5658
use jni::objects::JByteBuffer;
5759
use jni::sys::JNI_FALSE;
@@ -284,6 +286,12 @@ fn prepare_datafusion_session_context(
284286

285287
datafusion::functions_nested::register_all(&mut session_ctx)?;
286288

289+
// register UDFs from datafusion-spark crate
290+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
291+
292+
// Must be the last one to override existing functions with the same name
293+
datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
294+
287295
Ok(session_ctx)
288296
}
289297

native/core/src/execution/planner.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ use datafusion::{
6565
prelude::SessionContext,
6666
};
6767
use datafusion_comet_spark_expr::{
68-
create_comet_physical_fun, create_negate_expr, SparkBitwiseCount, SparkBitwiseNot,
69-
SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
68+
create_comet_physical_fun, create_negate_expr, SparkHour, SparkMinute, SparkSecond,
7069
};
7170

7271
use crate::execution::operators::ExecutionError::GeneralError;
@@ -110,7 +109,6 @@ use datafusion_comet_spark_expr::{
110109
NormalizeNaNAndZero, RLike, SparkCastOptions, StartsWith, Stddev, StringSpaceExpr,
111110
SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance,
112111
};
113-
use datafusion_spark::function::math::expm1::SparkExpm1;
114112
use itertools::Itertools;
115113
use jni::objects::GlobalRef;
116114
use num::{BigInt, ToPrimitive};
@@ -154,11 +152,6 @@ impl Default for PhysicalPlanner {
154152

155153
impl PhysicalPlanner {
156154
pub fn new(session_ctx: Arc<SessionContext>) -> Self {
157-
// register UDFs from datafusion-spark crate
158-
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
159-
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
160-
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseCount::default()));
161-
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateTrunc::default()));
162155
Self {
163156
exec_context_id: TEST_EXEC_CONTEXT_ID,
164157
session_ctx,

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::{
2020
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
2121
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
2222
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
23-
SparkChrFunc,
23+
SparkBitwiseCount, SparkBitwiseNot, SparkChrFunc, SparkDateTrunc,
2424
};
2525
use arrow::datatypes::DataType;
2626
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -112,7 +112,6 @@ pub fn create_comet_physical_fun(
112112
let func = Arc::new(spark_xxhash64);
113113
make_comet_scalar_udf!("xxhash64", func, without data_type)
114114
}
115-
"chr" => Ok(Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default()))),
116115
"isnan" => {
117116
let func = Arc::new(spark_isnan);
118117
make_comet_scalar_udf!("isnan", func, without data_type)
@@ -153,6 +152,25 @@ pub fn create_comet_physical_fun(
153152
}
154153
}
155154

155+
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
156+
vec![
157+
Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default())),
158+
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseNot::default())),
159+
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
160+
Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
161+
]
162+
}
163+
164+
/// Registers all custom UDFs
165+
pub fn register_all_comet_functions(registry: &mut dyn FunctionRegistry) -> DataFusionResult<()> {
166+
// This will override existing UDFs with the same name
167+
all_scalar_functions()
168+
.into_iter()
169+
.try_for_each(|udf| registry.register_udf(udf).map(|_| ()))?;
170+
171+
Ok(())
172+
}
173+
156174
struct CometScalarFunction {
157175
name: String,
158176
signature: Signature,

native/spark-expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ pub use bitwise_funcs::*;
5858
pub use conditional_funcs::*;
5959
pub use conversion_funcs::*;
6060

61-
pub use comet_scalar_funcs::create_comet_physical_fun;
61+
pub use comet_scalar_funcs::{create_comet_physical_fun, register_all_comet_functions};
6262
pub use datetime_funcs::{
6363
spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
6464
TimestampTruncExpr,

0 commit comments

Comments
 (0)