4 changes: 4 additions & 0 deletions native/spark-expr/Cargo.toml
@@ -80,6 +80,10 @@ harness = false
name = "padding"
harness = false

[[bench]]
name = "check_overflow"
harness = false

[[bench]]
name = "date_trunc"
harness = false
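For context: harness = false on the new check_overflow target disables the default libtest bench harness, so that Criterion's criterion_main! entry point in benches/check_overflow.rs (added below) can provide the binary's main function, matching how the crate's other benchmark targets are declared.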
152 changes: 152 additions & 0 deletions native/spark-expr/benches/check_overflow.rs
@@ -0,0 +1,152 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::builder::Decimal128Builder;
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::CheckOverflow;
use std::sync::Arc;

fn create_decimal_batch(size: usize, precision: u8, scale: i8, with_nulls: bool) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Decimal128(precision, scale),
        true,
    )]));
    let mut builder = Decimal128Builder::with_capacity(size);

    for i in 0..size {
        if with_nulls && i % 10 == 0 {
            builder.append_null();
        } else {
            // Values that fit within precision 10 (max ~9999999999)
            builder.append_value((i as i128) * 12345);
        }
    }

    let array = builder
        .finish()
        .with_precision_and_scale(precision, scale)
        .unwrap();
    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

fn create_batch_with_overflow(
    size: usize,
    input_precision: u8,
    target_precision: u8,
    scale: i8,
) -> RecordBatch {
    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Decimal128(input_precision, scale),
        true,
    )]));
    let mut builder = Decimal128Builder::with_capacity(size);

    // Create values where ~10% will overflow the target precision
    let max_for_target = 10i128.pow(target_precision as u32) - 1;
    for i in 0..size {
        if i % 10 == 0 {
            // This value will overflow target precision
            builder.append_value(max_for_target + (i as i128) + 1);
        } else {
            // This value is within target precision
            builder.append_value((i as i128) % max_for_target);
        }
    }

    let array = builder
        .finish()
        .with_precision_and_scale(input_precision, scale)
        .unwrap();
    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

fn criterion_benchmark(c: &mut Criterion) {
    let sizes = [1000, 10000];

    let mut group = c.benchmark_group("check_overflow");

    for size in sizes {
        // Benchmark: No overflow possible (precision already fits)
        // This tests the fast path where input precision <= target precision
        let batch_no_overflow = create_decimal_batch(size, 10, 2, false);

        // Create CheckOverflow that goes from precision 10 to 18 (no overflow possible)
        let check_overflow_no_validation = Arc::new(CheckOverflow::new(
            Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
            DataType::Decimal128(18, 2), // larger precision = no overflow possible
            false,
        ));

        group.bench_with_input(
            BenchmarkId::new("no_overflow_possible", size),
            &batch_no_overflow,
            |b, batch| {
                b.iter(|| check_overflow_no_validation.evaluate(batch).unwrap());
            },
        );

        // Benchmark: Validation needed, but no overflows occur
        let batch_valid = create_decimal_batch(size, 18, 2, true);
        let check_overflow_valid = Arc::new(CheckOverflow::new(
            Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
            DataType::Decimal128(10, 2), // smaller precision, need to validate
            false,
        ));

        group.bench_with_input(
            BenchmarkId::new("validation_no_overflow", size),
            &batch_valid,
            |b, batch| {
                b.iter(|| check_overflow_valid.evaluate(batch).unwrap());
            },
        );

        // Benchmark: With ~10% overflows (requires null insertion)
        let batch_with_overflow = create_batch_with_overflow(size, 18, 8, 2);
        let check_overflow_with_nulls = Arc::new(CheckOverflow::new(
            Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
            DataType::Decimal128(8, 2),
            false,
        ));

        group.bench_with_input(
            BenchmarkId::new("with_overflow_to_null", size),
            &batch_with_overflow,
            |b, batch| {
                b.iter(|| check_overflow_with_nulls.evaluate(batch).unwrap());
            },
        );
    }

    group.finish();
}

fn config() -> Criterion {
    Criterion::default()
}

criterion_group! {
    name = benches;
    config = config();
    targets = criterion_benchmark
}
criterion_main!(benches);
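For reference, each case above is named by group, label, and size (for example check_overflow/validation_no_overflow/10000). Assuming the standard Criterion CLI, the whole suite, or a single scenario via a name filter, can presumably be run with invocations like:

    cargo bench --bench check_overflow
    cargo bench --bench check_overflow -- with_overflow_to_null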
99 changes: 68 additions & 31 deletions native/spark-expr/src/math_funcs/internal/checkoverflow.rs
@@ -15,10 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.

+use crate::utils::is_valid_decimal_precision;
 use arrow::datatypes::{DataType, Schema};
 use arrow::{
-    array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
-    datatypes::{Decimal128Type, DecimalType},
+    array::{as_primitive_array, Array},
+    datatypes::Decimal128Type,
     record_batch::RecordBatch,
 };
 use datafusion::common::{DataFusionError, ScalarValue};
@@ -101,8 +102,8 @@ impl PhysicalExpr for CheckOverflow {
             ColumnarValue::Array(array)
                 if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
             {
-                let (precision, scale) = match &self.data_type {
-                    DataType::Decimal128(p, s) => (p, s),
+                let (target_precision, target_scale) = match &self.data_type {
+                    DataType::Decimal128(p, s) => (*p, *s),
                     dt => {
                         return Err(DataFusionError::Execution(format!(
                             "CheckOverflow expects only Decimal128, but got {dt:?}"
@@ -112,38 +113,74 @@

                 let decimal_array = as_primitive_array::<Decimal128Type>(&array);

-                let casted_array = if self.fail_on_error {
-                    // Returning error if overflow
-                    decimal_array.validate_decimal_precision(*precision)?;
+                let result_array = if self.fail_on_error {
+                    // ANSI mode: validate and return error on overflow
+                    // Use optimized validation that avoids error string allocation until needed
+                    for i in 0..decimal_array.len() {
+                        if decimal_array.is_valid(i) {
+                            let value = decimal_array.value(i);
+                            if !is_valid_decimal_precision(value, target_precision) {
+                                return Err(DataFusionError::ArrowError(
+                                    Box::new(arrow::error::ArrowError::InvalidArgumentError(
+                                        format!(
+                                            "{} is not a valid Decimal128 value with precision {}",
+                                            value, target_precision
+                                        ),
+                                    )),
+                                    None,
+                                ));
+                            }
+                        }
+                    }
+                    // Validation passed - just update metadata without copying data
                     decimal_array
+                        .clone()
+                        .with_precision_and_scale(target_precision, target_scale)?
                 } else {
-                    // Overflowing gets null value
-                    &decimal_array.null_if_overflow_precision(*precision)
+                    // Legacy/Try mode: convert overflows to null
+                    // Use Arrow's optimized null_if_overflow_precision which does a single pass
+                    let result = decimal_array.null_if_overflow_precision(target_precision);
+                    result.with_precision_and_scale(target_precision, target_scale)?
                 };

-                let new_array = Decimal128Array::from(casted_array.into_data())
-                    .with_precision_and_scale(*precision, *scale)
-                    .map(|a| Arc::new(a) as ArrayRef)?;
-
-                Ok(ColumnarValue::Array(new_array))
+                Ok(ColumnarValue::Array(Arc::new(result_array)))
             }
-            ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
-                // `fail_on_error` is only true when ANSI is enabled, which we don't support yet
-                // (Java side will simply fallback to Spark when it is enabled)
-                assert!(
-                    !self.fail_on_error,
-                    "fail_on_error (ANSI mode) is not supported yet"
-                );
-
-                let new_v: Option<i128> = v.and_then(|v| {
-                    Decimal128Type::validate_decimal_precision(v, precision, scale)
-                        .map(|_| v)
-                        .ok()
-                });
-
-                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
-                    new_v, precision, scale,
-                )))
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+                let (target_precision, target_scale) = match &self.data_type {
+                    DataType::Decimal128(p, s) => (*p, *s),
+                    dt => {
+                        return Err(DataFusionError::Execution(format!(
+                            "CheckOverflow expects only Decimal128 for scalar, but got {dt:?}"
+                        )))
+                    }
+                };
+
+                if self.fail_on_error {
+                    if let Some(value) = v {
+                        if !is_valid_decimal_precision(value, target_precision) {
+                            return Err(DataFusionError::ArrowError(
+                                Box::new(arrow::error::ArrowError::InvalidArgumentError(format!(
+                                    "{} is not a valid Decimal128 value with precision {}",
+                                    value, target_precision
+                                ))),
+                                None,
+                            ));
+                        }
+                    }
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        v,
+                        target_precision,
+                        target_scale,
+                    )))
+                } else {
+                    // Use optimized bool check instead of Result-returning validation
+                    let new_v = v.filter(|&val| is_valid_decimal_precision(val, target_precision));
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        new_v,
+                        target_precision,
+                        target_scale,
+                    )))
+                }
             }
             v => Err(DataFusionError::Execution(format!(
                 "CheckOverflow's child expression should be decimal array, but found {v:?}"
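The hot path in the array and scalar branches above relies on is_valid_decimal_precision from crate::utils returning a plain bool. A minimal sketch of that idea, assuming a simple range check (the crate's actual helper may be implemented differently, for example with precomputed per-precision bounds):

    // Hypothetical illustration only; the real crate::utils helper may differ.
    // Returns true when `value` fits in `precision` decimal digits, so the
    // validation loop formats no error string for in-range values.
    pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
        debug_assert!((1..=38).contains(&precision));
        let max = 10_i128.pow(precision as u32) - 1;
        (-max..=max).contains(&value)
    }

Compared with a Result-returning validation such as Arrow's validate_decimal_precision, which has to build an error value on the failure path, the bool form lets the caller decide when, and whether, to allocate the error message.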