diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index d0a6e2be75e0b..d88d56a94ed7d 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -32,12 +32,13 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let trunc = trunc(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let return_field = Field::new("f", DataType::Float32, true).into(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { @@ -74,6 +75,51 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - to measure optimized performance + let scalar_f64_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float64(Some(std::f64::consts::PI)), + )]; + let scalar_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let scalar_return_field = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("trunc f64 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float32(Some(std::f32::consts::PI)), + )]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let scalar_f32_return_field = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("trunc f32 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_f32_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 6727ba8fbdf08..bd21eeef179d3 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -24,7 +24,7 @@ use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -110,7 +110,50 @@ impl ScalarUDFImpl for TruncFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(trunc, vec![])(&args.args) + // Extract precision from second argument (default 0) + let precision = match args.args.get(1) { + Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p), + Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision + Some(ColumnarValue::Array(_)) => { + // Precision is an array - use array path + return make_scalar_function(trunc, vec![])(&args.args); + } + None => Some(0), // default precision + Some(cv) => { + return exec_err!( + "trunc function requires precision to be Int64, got {:?}", + cv.data_type() + ); + } + }; + + // Scalar fast path using tuple matching for (value, precision) + match (&args.args[0], precision) { + // Null cases + (ColumnarValue::Scalar(sv), _) if sv.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + (_, None) => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + // Scalar cases + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 { + v.trunc() + } else { + compute_truncate64(*v, p) + }))), + ), + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 { + v.trunc() + } else { + compute_truncate32(*v, p) + }))), + ), + // Array path for everything else + _ => make_scalar_function(trunc, vec![])(&args.args), + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result {