From 9b6882fd8066a1f90f571a9d8a85fa44ecf843b5 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Thu, 15 Jan 2026 12:57:19 +0530 Subject: [PATCH] perf: Optimize round scalar performance --- datafusion/functions/Cargo.toml | 5 + datafusion/functions/benches/round.rs | 154 +++++++++++++++++++++++++ datafusion/functions/src/math/round.rs | 44 +++++++ 3 files changed, 203 insertions(+) create mode 100644 datafusion/functions/benches/round.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 939fcfd11fba0..610ab1617a8d1 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -320,3 +320,8 @@ required-features = ["math_expressions"] harness = false name = "floor_ceil" required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "round" +required-features = ["math_expressions"] diff --git a/datafusion/functions/benches/round.rs b/datafusion/functions/benches/round.rs new file mode 100644 index 0000000000000..ea59584919d68 --- /dev/null +++ b/datafusion/functions/benches/round.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::round; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let round_fn = round(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("round size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ + ColumnarValue::Array(f64_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Float32 array benchmark + let f32_array = Arc::new(create_primitive_array::(size, 0.1)); + let f32_args = vec![ + ColumnarValue::Array(f32_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_array", |b| { + b.iter(|| { + let args_cloned = f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index de70788128b88..c2e856bcd8e39 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -141,6 +141,50 @@ impl ScalarUDFImpl for RoundFunc { &default_decimal_places }; + // Scalar fast path for float types - avoid array conversion overhead + if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) = + (&args.args[0], decimal_places) + { + // Extract decimal places as i32 + let dp = match dp_scalar { + ScalarValue::Int32(Some(dp)) => *dp, + ScalarValue::Int32(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + _ => { + // Fall through to array path for non-Int32 decimal places + return round_columnar( + &args.args[0], + decimal_places, + args.number_rows, + ); + } + }; + + match value_scalar { + ScalarValue::Float64(Some(v)) => { + let factor = 10_f64.powi(dp); + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + (v * factor).round() / factor, + )))); + } + ScalarValue::Float64(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + ScalarValue::Float32(Some(v)) => { + let factor = 10_f32.powi(dp); + return Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some( + (v * factor).round() / factor, + )))); + } + ScalarValue::Float32(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32(None))); + } + // For decimals and other types: fall through to array path + _ => {} + } + } + round_columnar(&args.args[0], decimal_places, args.number_rows) }