diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 94653d8864..625bde9956 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -80,6 +80,10 @@ harness = false
 name = "padding"
 harness = false
 
+[[bench]]
+name = "check_overflow"
+harness = false
+
 [[bench]]
 name = "date_trunc"
 harness = false
diff --git a/native/spark-expr/benches/check_overflow.rs b/native/spark-expr/benches/check_overflow.rs
new file mode 100644
index 0000000000..9351a3fee4
--- /dev/null
+++ b/native/spark-expr/benches/check_overflow.rs
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::builder::Decimal128Builder;
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::CheckOverflow;
+use std::sync::Arc;
+
+fn create_decimal_batch(size: usize, precision: u8, scale: i8, with_nulls: bool) -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "a",
+        DataType::Decimal128(precision, scale),
+        true,
+    )]));
+    let mut builder = Decimal128Builder::with_capacity(size);
+
+    for i in 0..size {
+        if with_nulls && i % 10 == 0 {
+            builder.append_null();
+        } else {
+            // Values that fit within precision 10 (max ~9999999999)
+            builder.append_value((i as i128) * 12345);
+        }
+    }
+
+    let array = builder
+        .finish()
+        .with_precision_and_scale(precision, scale)
+        .unwrap();
+    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
+}
+
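+// Helper for the overflow benchmark below: every 10th value is pushed just past
+// the largest value representable with `target_precision` digits (10^p - 1), so
+// roughly 10% of rows overflow the target precision while every row still fits
+// the wider `input_precision` of the source column.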
+fn create_batch_with_overflow(
+    size: usize,
+    input_precision: u8,
+    target_precision: u8,
+    scale: i8,
+) -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "a",
+        DataType::Decimal128(input_precision, scale),
+        true,
+    )]));
+    let mut builder = Decimal128Builder::with_capacity(size);
+
+    // Create values where ~10% will overflow the target precision
+    let max_for_target = 10i128.pow(target_precision as u32) - 1;
+    for i in 0..size {
+        if i % 10 == 0 {
+            // This value will overflow target precision
+            builder.append_value(max_for_target + (i as i128) + 1);
+        } else {
+            // This value is within target precision
+            builder.append_value((i as i128) % max_for_target);
+        }
+    }
+
+    let array = builder
+        .finish()
+        .with_precision_and_scale(input_precision, scale)
+        .unwrap();
+    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let sizes = [1000, 10000];
+
+    let mut group = c.benchmark_group("check_overflow");
+
+    for size in sizes {
+        // Benchmark: No overflow possible (precision already fits)
+        // This tests the fast path where input precision <= target precision
+        let batch_no_overflow = create_decimal_batch(size, 10, 2, false);
+
+        // Create CheckOverflow that goes from precision 10 to 18 (no overflow possible)
+        let check_overflow_no_validation = Arc::new(CheckOverflow::new(
+            Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
+            DataType::Decimal128(18, 2), // larger precision = no overflow possible
+            false,
+        ));
+
+        group.bench_with_input(
+            BenchmarkId::new("no_overflow_possible", size),
+            &batch_no_overflow,
+            |b, batch| {
+                b.iter(|| check_overflow_no_validation.evaluate(batch).unwrap());
+            },
+        );
+
+        // Benchmark: Validation needed, but no overflows occur
+        let batch_valid = create_decimal_batch(size, 18, 2, true);
+        let check_overflow_valid = Arc::new(CheckOverflow::new(
+            Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
+            DataType::Decimal128(10, 2), // smaller precision, need to validate
+            false,
+        ));
+
+        group.bench_with_input(
+            BenchmarkId::new("validation_no_overflow", size),
+            &batch_valid,
+            |b, batch| {
+                b.iter(|| check_overflow_valid.evaluate(batch).unwrap());
+            },
+        );
+
+        // Benchmark: With ~10% overflows (requires null insertion)
+        let batch_with_overflow = create_batch_with_overflow(size, 18, 8, 2);
+        let check_overflow_with_nulls = Arc::new(CheckOverflow::new(
+            Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
+            DataType::Decimal128(8, 2),
+            false,
+        ));
+
+        group.bench_with_input(
+            BenchmarkId::new("with_overflow_to_null", size),
+            &batch_with_overflow,
+            |b, batch| {
+                b.iter(|| check_overflow_with_nulls.evaluate(batch).unwrap());
+            },
+        );
+    }
+
+    group.finish();
+}
+
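+// Criterion defaults are used; if the 10k-row cases run too long this could be
+// tightened to e.g. `Criterion::default().sample_size(50)`.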
+fn config() -> Criterion {
+    Criterion::default()
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
index 9773a107af..f5affe39f7 100644
--- a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
+++ b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
@@ -15,10 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::is_valid_decimal_precision;
 use arrow::datatypes::{DataType, Schema};
 use arrow::{
-    array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
-    datatypes::{Decimal128Type, DecimalType},
+    array::{as_primitive_array, Array},
+    datatypes::Decimal128Type,
     record_batch::RecordBatch,
 };
 use datafusion::common::{DataFusionError, ScalarValue};
@@ -101,8 +102,8 @@ impl PhysicalExpr for CheckOverflow {
             ColumnarValue::Array(array)
                 if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
             {
-                let (precision, scale) = match &self.data_type {
-                    DataType::Decimal128(p, s) => (p, s),
+                let (target_precision, target_scale) = match &self.data_type {
+                    DataType::Decimal128(p, s) => (*p, *s),
                     dt => {
                         return Err(DataFusionError::Execution(format!(
                             "CheckOverflow expects only Decimal128, but got {dt:?}"
@@ -112,38 +113,74 @@ impl PhysicalExpr for CheckOverflow {
 
                 let decimal_array = as_primitive_array::<Decimal128Type>(&array);
 
-                let casted_array = if self.fail_on_error {
-                    // Returning error if overflow
-                    decimal_array.validate_decimal_precision(*precision)?;
+                let result_array = if self.fail_on_error {
+                    // ANSI mode: validate and return error on overflow
+                    // Use optimized validation that avoids error string allocation until needed
+                    for i in 0..decimal_array.len() {
+                        if decimal_array.is_valid(i) {
+                            let value = decimal_array.value(i);
+                            if !is_valid_decimal_precision(value, target_precision) {
+                                return Err(DataFusionError::ArrowError(
+                                    Box::new(arrow::error::ArrowError::InvalidArgumentError(
+                                        format!(
+                                            "{} is not a valid Decimal128 value with precision {}",
+                                            value, target_precision
+                                        ),
+                                    )),
+                                    None,
+                                ));
+                            }
+                        }
+                    }
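+                    // Note: `is_valid_decimal_precision` returns a bool, so the
+                    // per-row check stays allocation-free; an error string is only
+                    // formatted for the first value that actually overflows.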
+                    // Validation passed - just update metadata without copying data
                     decimal_array
+                        .clone()
+                        .with_precision_and_scale(target_precision, target_scale)?
                 } else {
-                    // Overflowing gets null value
-                    &decimal_array.null_if_overflow_precision(*precision)
+                    // Legacy/Try mode: convert overflows to null
+                    // Use Arrow's optimized null_if_overflow_precision which does a single pass
+                    let result = decimal_array.null_if_overflow_precision(target_precision);
+                    result.with_precision_and_scale(target_precision, target_scale)?
                 };
 
-                let new_array = Decimal128Array::from(casted_array.into_data())
-                    .with_precision_and_scale(*precision, *scale)
-                    .map(|a| Arc::new(a) as ArrayRef)?;
-
-                Ok(ColumnarValue::Array(new_array))
+                Ok(ColumnarValue::Array(Arc::new(result_array)))
             }
-            ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
-                // `fail_on_error` is only true when ANSI is enabled, which we don't support yet
-                // (Java side will simply fallback to Spark when it is enabled)
-                assert!(
-                    !self.fail_on_error,
-                    "fail_on_error (ANSI mode) is not supported yet"
-                );
-
-                let new_v: Option<i128> = v.and_then(|v| {
-                    Decimal128Type::validate_decimal_precision(v, precision, scale)
-                        .map(|_| v)
-                        .ok()
-                });
-
-                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
-                    new_v, precision, scale,
-                )))
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+                let (target_precision, target_scale) = match &self.data_type {
+                    DataType::Decimal128(p, s) => (*p, *s),
+                    dt => {
+                        return Err(DataFusionError::Execution(format!(
+                            "CheckOverflow expects only Decimal128 for scalar, but got {dt:?}"
+                        )))
+                    }
+                };
+
+                if self.fail_on_error {
+                    if let Some(value) = v {
+                        if !is_valid_decimal_precision(value, target_precision) {
+                            return Err(DataFusionError::ArrowError(
+                                Box::new(arrow::error::ArrowError::InvalidArgumentError(format!(
+                                    "{} is not a valid Decimal128 value with precision {}",
+                                    value, target_precision
+                                ))),
+                                None,
+                            ));
+                        }
+                    }
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        v,
+                        target_precision,
+                        target_scale,
+                    )))
+                } else {
+                    // Use optimized bool check instead of Result-returning validation
+                    let new_v = v.filter(|&val| is_valid_decimal_precision(val, target_precision));
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        new_v,
+                        target_precision,
+                        target_scale,
+                    )))
+                }
             }
             v => Err(DataFusionError::Execution(format!(
                 "CheckOverflow's child expression should be decimal array, but found {v:?}"