From 8ccd4aa5eb997af1468f51cb62b350ec3d93e3a5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Jan 2026 12:43:58 -0700 Subject: [PATCH 1/8] use trim_ascii --- .../spark-expr/src/conversion_funcs/cast.rs | 33 ++++--------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 314beb18ca..2e3c031ab6 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2009,21 +2009,6 @@ fn cast_string_to_int_with_range_check( } } -// Returns (start, end) indices after trimming whitespace -fn trim_whitespace(bytes: &[u8]) -> (usize, usize) { - let mut start = 0; - let mut end = bytes.len(); - - while start < end && bytes[start].is_ascii_whitespace() { - start += 1; - } - while end > start && bytes[end - 1].is_ascii_whitespace() { - end -= 1; - } - - (start, end) -} - // Parses sign and returns (is_negative, start_idx after sign) // Returns None if invalid (e.g., just "+" or "-") fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> { @@ -2052,13 +2037,11 @@ fn do_parse_string_to_int_legacy str: &str, min_value: T, ) -> SparkResult> { - let bytes = str.as_bytes(); - let (start, end) = trim_whitespace(bytes); + let trimmed_bytes = str.as_bytes().trim_ascii(); - if start == end { + if trimmed_bytes.is_empty() { return Ok(None); } - let trimmed_bytes = &bytes[start..end]; let (negative, idx) = match parse_sign(trimmed_bytes) { Some(result) => result, @@ -2122,13 +2105,11 @@ fn do_parse_string_to_int_ansi + type_name: &str, min_value: T, ) -> SparkResult> { - let bytes = str.as_bytes(); - let (start, end) = trim_whitespace(bytes); + let trimmed_bytes = str.as_bytes().trim_ascii(); - if start == end { + if trimmed_bytes.is_empty() { return Err(invalid_value(str, "STRING", type_name)); } - let trimmed_bytes = &bytes[start..end]; let (negative, idx) = match parse_sign(trimmed_bytes) { Some(result) => result, @@ -2181,13 +2162,11 @@ fn do_parse_string_to_int_try + str: &str, min_value: T, ) -> SparkResult> { - let bytes = str.as_bytes(); - let (start, end) = trim_whitespace(bytes); + let trimmed_bytes = str.as_bytes().trim_ascii(); - if start == end { + if trimmed_bytes.is_empty() { return Ok(None); } - let trimmed_bytes = &bytes[start..end]; let (negative, idx) = match parse_sign(trimmed_bytes) { Some(result) => result, From 06ae139ed4f09ec6f9aedb7352eb154608e80ecb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Jan 2026 12:55:35 -0700 Subject: [PATCH 2/8] minor --- native/spark-expr/src/conversion_funcs/cast.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 2e3c031ab6..9417114dfc 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2066,12 +2066,11 @@ fn do_parse_string_to_int_legacy return Ok(None); } - let digit: T = T::from(ch - b'0'); - if result < stop_value { return Ok(None); } let v = result * radix; + let digit: T = T::from(ch - b'0'); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, _ => { @@ -2130,12 +2129,11 @@ fn do_parse_string_to_int_ansi + return Err(invalid_value(str, "STRING", type_name)); } - let digit: T = T::from(ch - b'0'); - if result < stop_value { return Err(invalid_value(str, "STRING", type_name)); } let v = result * radix; + let digit: T = T::from(ch - b'0'); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, _ => { @@ -2188,12 +2186,11 @@ fn do_parse_string_to_int_try + return Ok(None); } - let digit: T = T::from(ch - b'0'); - if result < stop_value { return Ok(None); } let v = result * radix; + let digit: T = T::from(ch - b'0'); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, _ => { From cbf68bb1b81f69014f2208386d6c1f48c8400ff8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Jan 2026 13:00:02 -0700 Subject: [PATCH 3/8] use two loops and remove mutable parse_sign_and_digits variable --- .../spark-expr/src/conversion_funcs/cast.rs | 53 +++++++++---------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 9417114dfc..00db5a6381 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2049,39 +2049,36 @@ fn do_parse_string_to_int_legacy }; let mut result: T = T::zero(); - let radix = T::from(10_u8); let stop_value = min_value / radix; - let mut parse_sign_and_digits = true; - for &ch in &trimmed_bytes[idx..] { - if parse_sign_and_digits { - if ch == b'.' { - // truncate decimal in legacy mode - parse_sign_and_digits = false; - continue; - } + let mut iter = trimmed_bytes[idx..].iter(); - if !ch.is_ascii_digit() { - return Ok(None); - } + // Parse integer portion until '.' or end + for &ch in iter.by_ref() { + if ch == b'.' { + break; + } - if result < stop_value { - return Ok(None); - } - let v = result * radix; - let digit: T = T::from(ch - b'0'); - match v.checked_sub(&digit) { - Some(x) if x <= T::zero() => result = x, - _ => { - return Ok(None); - } - } - } else { - // in legacy mode we still process chars after the dot and make sure the chars are digits - if !ch.is_ascii_digit() { - return Ok(None); - } + if !ch.is_ascii_digit() { + return Ok(None); + } + + if result < stop_value { + return Ok(None); + } + let v = result * radix; + let digit: T = T::from(ch - b'0'); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => return Ok(None), + } + } + + // Validate decimal portion (digits only, values ignored) + for &ch in iter { + if !ch.is_ascii_digit() { + return Ok(None); } } From e28d52ae7aadd4539fda41e69a15c17af25d7b16 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Jan 2026 13:25:24 -0700 Subject: [PATCH 4/8] remove per-row eval mode check and expand benchmarks --- native/spark-expr/benches/cast_from_string.rs | 74 +++++--- .../spark-expr/src/conversion_funcs/cast.rs | 171 +++++++++++------- 2 files changed, 155 insertions(+), 90 deletions(-) diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs index 990cdec213..74b2c77a3d 100644 --- a/native/spark-expr/benches/cast_from_string.rs +++ b/native/spark-expr/benches/cast_from_string.rs @@ -23,45 +23,75 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { - let batch = create_utf8_batch(); + let int_batch = create_int_string_batch(); + let decimal_batch = create_decimal_string_batch(); let expr = Arc::new(Column::new("a", 0)); + + for (mode, mode_name) in [ + (EvalMode::Legacy, "legacy"), + (EvalMode::Ansi, "ansi"), + (EvalMode::Try, "try"), + ] { + let spark_cast_options = SparkCastOptions::new(mode, "", false); + let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); + let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); + + let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name)); + group.bench_function("i32", |b| { + b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap()); + }); + group.bench_function("i64", |b| { + b.iter(|| cast_to_i64.evaluate(&int_batch).unwrap()); + }); + group.finish(); + } + + // Benchmark decimal truncation (Legacy mode only) let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false); - let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); - let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); - let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); - let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options); + let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); + let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); - let mut group = c.benchmark_group("cast_string_to_int"); - group.bench_function("cast_string_to_i8", |b| { - b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals"); + group.bench_function("i32", |b| { + b.iter(|| cast_to_i32.evaluate(&decimal_batch).unwrap()); }); - group.bench_function("cast_string_to_i16", |b| { - b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); - }); - group.bench_function("cast_string_to_i32", |b| { - b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); - }); - group.bench_function("cast_string_to_i64", |b| { - b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + group.bench_function("i64", |b| { + b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap()); }); + group.finish(); } -// Create UTF8 batch with strings representing ints, floats, nulls -fn create_utf8_batch() -> RecordBatch { +/// Create batch with valid integer strings (works for all eval modes) +fn create_int_string_batch() -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); let mut b = StringBuilder::new(); for i in 0..1000 { if i % 10 == 0 { b.append_null(); - } else if i % 2 == 0 { - b.append_value(format!("{}", rand::random::())); } else { - b.append_value(format!("{}", rand::random::())); + b.append_value(format!("{}", rand::random::())); } } let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() +} - RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() +/// Create batch with decimal strings (for Legacy mode decimal truncation) +fn create_decimal_string_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + // Generate integers with decimal portions to test truncation + let int_part: i32 = rand::random(); + let dec_part: u32 = rand::random::() % 1000; + b.append_value(format!("{}.{}", int_part, dec_part)); + } + } + let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() } fn config() -> Criterion { diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 00db5a6381..41bc44d301 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -386,12 +386,13 @@ fn can_cast_from_decimal( } macro_rules! cast_utf8_to_int { - ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + ($array:expr, $array_type:ty, $parse_fn:expr) => {{ let len = $array.len(); let mut cast_array = PrimitiveArray::<$array_type>::builder(len); + let parse_fn = $parse_fn; if $array.null_count() == 0 { for i in 0..len { - if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + if let Some(cast_value) = parse_fn($array.value(i))? { cast_array.append_value(cast_value); } else { cast_array.append_null() @@ -401,7 +402,7 @@ macro_rules! cast_utf8_to_int { for i in 0..len { if $array.is_null(i) { cast_array.append_null() - } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + } else if let Some(cast_value) = parse_fn($array.value(i))? { cast_array.append_value(cast_value); } else { cast_array.append_null() @@ -1473,22 +1474,70 @@ fn cast_string_to_int( .downcast_ref::>() .expect("cast_string_to_int expected a string array"); - let cast_array: ArrayRef = match to_type { - DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?, - DataType::Int16 => { - cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? - } - DataType::Int32 => { - cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? - } - DataType::Int64 => { - cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? - } - dt => unreachable!( - "{}", - format!("invalid integer type {dt} in cast from string") - ), - }; + // Select parse function once per batch based on eval_mode + let cast_array: ArrayRef = + match (to_type, eval_mode) { + (DataType::Int8, EvalMode::Legacy) => { + cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_legacy)? + } + (DataType::Int8, EvalMode::Ansi) => { + cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_ansi)? + } + (DataType::Int8, EvalMode::Try) => { + cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_try)? + } + (DataType::Int16, EvalMode::Legacy) => { + cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_legacy)? + } + (DataType::Int16, EvalMode::Ansi) => { + cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_ansi)? + } + (DataType::Int16, EvalMode::Try) => { + cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_try)? + } + (DataType::Int32, EvalMode::Legacy) => cast_utf8_to_int!( + string_array, + Int32Type, + |s| do_parse_string_to_int_legacy::(s, i32::MIN) + )?, + (DataType::Int32, EvalMode::Ansi) => { + cast_utf8_to_int!(string_array, Int32Type, |s| do_parse_string_to_int_ansi::< + i32, + >( + s, "INT", i32::MIN + ))? + } + (DataType::Int32, EvalMode::Try) => { + cast_utf8_to_int!( + string_array, + Int32Type, + |s| do_parse_string_to_int_try::(s, i32::MIN) + )? + } + (DataType::Int64, EvalMode::Legacy) => cast_utf8_to_int!( + string_array, + Int64Type, + |s| do_parse_string_to_int_legacy::(s, i64::MIN) + )?, + (DataType::Int64, EvalMode::Ansi) => { + cast_utf8_to_int!(string_array, Int64Type, |s| do_parse_string_to_int_ansi::< + i64, + >( + s, "BIGINT", i64::MIN + ))? + } + (DataType::Int64, EvalMode::Try) => { + cast_utf8_to_int!( + string_array, + Int64Type, + |s| do_parse_string_to_int_try::(s, i64::MIN) + )? + } + (dt, _) => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), + }; Ok(cast_array) } @@ -1960,51 +2009,50 @@ fn spark_cast_nonintegral_numeric_to_integral( } } -/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte -fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult> { - Ok(cast_string_to_int_with_range_check( - str, - eval_mode, - "TINYINT", - i8::MIN as i32, - i8::MAX as i32, - )? - .map(|v| v as i8)) +fn parse_string_to_i8_legacy(str: &str) -> SparkResult> { + match do_parse_string_to_int_legacy::(str, i32::MIN)? { + None => Ok(None), + Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), + _ => Ok(None), + } } -/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort -fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult> { - Ok(cast_string_to_int_with_range_check( - str, - eval_mode, - "SMALLINT", - i16::MIN as i32, - i16::MAX as i32, - )? - .map(|v| v as i16)) +fn parse_string_to_i8_ansi(str: &str) -> SparkResult> { + match do_parse_string_to_int_ansi::(str, "TINYINT", i32::MIN)? { + None => Ok(None), + Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), + _ => Err(invalid_value(str, "STRING", "TINYINT")), + } } -/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) -fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult> { - do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) +fn parse_string_to_i8_try(str: &str) -> SparkResult> { + match do_parse_string_to_int_try::(str, i32::MIN)? { + None => Ok(None), + Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), + _ => Ok(None), + } } -/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) -fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult> { - do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) +fn parse_string_to_i16_legacy(str: &str) -> SparkResult> { + match do_parse_string_to_int_legacy::(str, i32::MIN)? { + None => Ok(None), + Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)), + _ => Ok(None), + } } -fn cast_string_to_int_with_range_check( - str: &str, - eval_mode: EvalMode, - type_name: &str, - min: i32, - max: i32, -) -> SparkResult> { - match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? { +fn parse_string_to_i16_ansi(str: &str) -> SparkResult> { + match do_parse_string_to_int_ansi::(str, "SMALLINT", i32::MIN)? { None => Ok(None), - Some(v) if v >= min && v <= max => Ok(Some(v)), - _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)), + _ => Err(invalid_value(str, "STRING", "SMALLINT")), + } +} + +fn parse_string_to_i16_try(str: &str) -> SparkResult> { + match do_parse_string_to_int_try::(str, i32::MIN)? { + None => Ok(None), + Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)), _ => Ok(None), } } @@ -2210,19 +2258,6 @@ fn do_parse_string_to_int_try + Ok(Some(result)) } -fn do_cast_string_to_int + Copy>( - str: &str, - eval_mode: EvalMode, - type_name: &str, - min_value: T, -) -> SparkResult> { - match eval_mode { - EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value), - EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value), - EvalMode::Try => do_parse_string_to_int_try(str, min_value), - } -} - fn cast_string_to_decimal( array: &ArrayRef, to_type: &DataType, From 37b90fbd5e2aad3e021f5560d27c99bbdef49ba3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Jan 2026 13:32:05 -0700 Subject: [PATCH 5/8] more cleanup --- .../spark-expr/src/conversion_funcs/cast.rs | 125 +++++------------- 1 file changed, 34 insertions(+), 91 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 41bc44d301..f57a69d52f 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2057,24 +2057,23 @@ fn parse_string_to_i16_try(str: &str) -> SparkResult> { } } -// Parses sign and returns (is_negative, start_idx after sign) -// Returns None if invalid (e.g., just "+" or "-") -fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> { - let len = trimmed_bytes.len(); - if len == 0 { - return None; +/// Parses sign and returns (is_negative, remaining_bytes after sign) +/// Returns None if invalid (empty input, or just "+" or "-") +fn parse_sign(bytes: &[u8]) -> Option<(bool, &[u8])> { + let (&first, rest) = bytes.split_first()?; + match first { + b'-' if !rest.is_empty() => Some((true, rest)), + b'+' if !rest.is_empty() => Some((false, rest)), + _ => Some((false, bytes)), } +} - let first_char = trimmed_bytes[0]; - let negative = first_char == b'-'; - - if negative || first_char == b'+' { - if len == 1 { - return None; - } - Some((negative, 1)) +/// Finalizes the result by applying the sign. Returns None if overflow would occur. +fn finalize_int_result(result: T, negative: bool) -> Option { + if negative { + Some(result) } else { - Some((false, 0)) + result.checked_neg().filter(|&n| n >= T::zero()) } } @@ -2087,11 +2086,7 @@ fn do_parse_string_to_int_legacy ) -> SparkResult> { let trimmed_bytes = str.as_bytes().trim_ascii(); - if trimmed_bytes.is_empty() { - return Ok(None); - } - - let (negative, idx) = match parse_sign(trimmed_bytes) { + let (negative, digits) = match parse_sign(trimmed_bytes) { Some(result) => result, None => return Ok(None), }; @@ -2100,7 +2095,7 @@ fn do_parse_string_to_int_legacy let radix = T::from(10_u8); let stop_value = min_value / radix; - let mut iter = trimmed_bytes[idx..].iter(); + let mut iter = digits.iter(); // Parse integer portion until '.' or end for &ch in iter.by_ref() { @@ -2130,18 +2125,7 @@ fn do_parse_string_to_int_legacy } } - if !negative { - if let Some(neg) = result.checked_neg() { - if neg < T::zero() { - return Ok(None); - } - result = neg; - } else { - return Ok(None); - } - } - - Ok(Some(result)) + Ok(finalize_int_result(result, negative)) } fn do_parse_string_to_int_ansi + Copy>( @@ -2149,56 +2133,38 @@ fn do_parse_string_to_int_ansi + type_name: &str, min_value: T, ) -> SparkResult> { - let trimmed_bytes = str.as_bytes().trim_ascii(); + let error = || Err(invalid_value(str, "STRING", type_name)); - if trimmed_bytes.is_empty() { - return Err(invalid_value(str, "STRING", type_name)); - } + let trimmed_bytes = str.as_bytes().trim_ascii(); - let (negative, idx) = match parse_sign(trimmed_bytes) { + let (negative, digits) = match parse_sign(trimmed_bytes) { Some(result) => result, - None => return Err(invalid_value(str, "STRING", type_name)), + None => return error(), }; let mut result: T = T::zero(); - let radix = T::from(10_u8); let stop_value = min_value / radix; - for &ch in &trimmed_bytes[idx..] { - if ch == b'.' { - return Err(invalid_value(str, "STRING", type_name)); - } - - if !ch.is_ascii_digit() { - return Err(invalid_value(str, "STRING", type_name)); + for &ch in digits { + if ch == b'.' || !ch.is_ascii_digit() { + return error(); } if result < stop_value { - return Err(invalid_value(str, "STRING", type_name)); + return error(); } let v = result * radix; let digit: T = T::from(ch - b'0'); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, - _ => { - return Err(invalid_value(str, "STRING", type_name)); - } - } - } - - if !negative { - if let Some(neg) = result.checked_neg() { - if neg < T::zero() { - return Err(invalid_value(str, "STRING", type_name)); - } - result = neg; - } else { - return Err(invalid_value(str, "STRING", type_name)); + _ => return error(), } } - Ok(Some(result)) + finalize_int_result(result, negative) + .map(Some) + .ok_or_else(|| invalid_value(str, "STRING", type_name)) } fn do_parse_string_to_int_try + Copy>( @@ -2207,27 +2173,17 @@ fn do_parse_string_to_int_try + ) -> SparkResult> { let trimmed_bytes = str.as_bytes().trim_ascii(); - if trimmed_bytes.is_empty() { - return Ok(None); - } - - let (negative, idx) = match parse_sign(trimmed_bytes) { + let (negative, digits) = match parse_sign(trimmed_bytes) { Some(result) => result, None => return Ok(None), }; let mut result: T = T::zero(); - let radix = T::from(10_u8); let stop_value = min_value / radix; - // we don't have to go beyond decimal point in try eval mode - early return NULL - for &ch in &trimmed_bytes[idx..] { - if ch == b'.' { - return Ok(None); - } - - if !ch.is_ascii_digit() { + for &ch in digits { + if ch == b'.' || !ch.is_ascii_digit() { return Ok(None); } @@ -2238,24 +2194,11 @@ fn do_parse_string_to_int_try + let digit: T = T::from(ch - b'0'); match v.checked_sub(&digit) { Some(x) if x <= T::zero() => result = x, - _ => { - return Ok(None); - } - } - } - - if !negative { - if let Some(neg) = result.checked_neg() { - if neg < T::zero() { - return Ok(None); - } - result = neg; - } else { - return Ok(None); + _ => return Ok(None), } } - Ok(Some(result)) + Ok(finalize_int_result(result, negative)) } fn cast_string_to_decimal( From b56aa04e042b6c43d0c7eea15b5e7e81eeb8386f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Jan 2026 14:13:35 -0700 Subject: [PATCH 6/8] fix regression and update benchmark --- native/spark-expr/benches/cast_from_string.rs | 24 +++++++++++++++++++ .../spark-expr/src/conversion_funcs/cast.rs | 9 +++++++ 2 files changed, 33 insertions(+) diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs index 74b2c77a3d..a09afae6e1 100644 --- a/native/spark-expr/benches/cast_from_string.rs +++ b/native/spark-expr/benches/cast_from_string.rs @@ -23,6 +23,7 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { + let small_int_batch = create_small_int_string_batch(); let int_batch = create_int_string_batch(); let decimal_batch = create_decimal_string_batch(); let expr = Arc::new(Column::new("a", 0)); @@ -33,10 +34,18 @@ fn criterion_benchmark(c: &mut Criterion) { (EvalMode::Try, "try"), ] { let spark_cast_options = SparkCastOptions::new(mode, "", false); + let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); + let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options); let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name)); + group.bench_function("i8", |b| { + b.iter(|| cast_to_i8.evaluate(&small_int_batch).unwrap()); + }); + group.bench_function("i16", |b| { + b.iter(|| cast_to_i16.evaluate(&small_int_batch).unwrap()); + }); group.bench_function("i32", |b| { b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap()); }); @@ -61,6 +70,21 @@ fn criterion_benchmark(c: &mut Criterion) { group.finish(); } +/// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks) +fn create_small_int_string_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() +} + /// Create batch with valid integer strings (works for all eval modes) fn create_int_string_batch() -> RecordBatch { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index f57a69d52f..9b717e415d 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -3004,6 +3004,15 @@ mod tests { use super::*; + /// Test helper that wraps the mode-specific parse functions + fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult> { + match eval_mode { + EvalMode::Legacy => parse_string_to_i8_legacy(str), + EvalMode::Ansi => parse_string_to_i8_ansi(str), + EvalMode::Try => parse_string_to_i8_try(str), + } + } + #[test] #[cfg_attr(miri, ignore)] // test takes too long with miri fn timestamp_parser_test() { From 1510f620c50f67d6ce4279ef2e8a4fb14c993b87 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 7 Jan 2026 09:13:49 -0700 Subject: [PATCH 7/8] upmerge and improve benchmark --- .../CometDatetimeExpressionBenchmark.scala | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala index 18f292ff24..6b6cd8cfa6 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala @@ -52,11 +52,20 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase { prepareTable( dir, spark.sql(s"select timestamp_micros(cast(value/100000 as integer)) as ts FROM $tbl")) - Seq("YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER").foreach { - level => - val name = s"Timestamp Truncate - $level" - val query = s"select date_trunc('$level', ts) from parquetV1Table" - runExpressionBenchmark(name, values, query) + Seq( + "YEAR", + "QUARTER", + "MONTH", + "WEEK", + "DAY", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND").foreach { level => + val name = s"Timestamp Truncate - $level" + val query = s"select date_trunc('$level', ts) from parquetV1Table" + runExpressionBenchmark(name, values, query) } } } From 1eb4b399f3b2d6d81e617ef5aff99c4e8db5e3fd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 8 Jan 2026 10:02:06 -0700 Subject: [PATCH 8/8] address feedback --- native/spark-expr/src/conversion_funcs/cast.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 9b717e415d..2ff1d8c551 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2011,7 +2011,6 @@ fn spark_cast_nonintegral_numeric_to_integral( fn parse_string_to_i8_legacy(str: &str) -> SparkResult> { match do_parse_string_to_int_legacy::(str, i32::MIN)? { - None => Ok(None), Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), _ => Ok(None), } @@ -2019,7 +2018,6 @@ fn parse_string_to_i8_legacy(str: &str) -> SparkResult> { fn parse_string_to_i8_ansi(str: &str) -> SparkResult> { match do_parse_string_to_int_ansi::(str, "TINYINT", i32::MIN)? { - None => Ok(None), Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), _ => Err(invalid_value(str, "STRING", "TINYINT")), } @@ -2027,7 +2025,6 @@ fn parse_string_to_i8_ansi(str: &str) -> SparkResult> { fn parse_string_to_i8_try(str: &str) -> SparkResult> { match do_parse_string_to_int_try::(str, i32::MIN)? { - None => Ok(None), Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)), _ => Ok(None), } @@ -2035,7 +2032,6 @@ fn parse_string_to_i8_try(str: &str) -> SparkResult> { fn parse_string_to_i16_legacy(str: &str) -> SparkResult> { match do_parse_string_to_int_legacy::(str, i32::MIN)? { - None => Ok(None), Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)), _ => Ok(None), } @@ -2043,7 +2039,6 @@ fn parse_string_to_i16_legacy(str: &str) -> SparkResult> { fn parse_string_to_i16_ansi(str: &str) -> SparkResult> { match do_parse_string_to_int_ansi::(str, "SMALLINT", i32::MIN)? { - None => Ok(None), Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)), _ => Err(invalid_value(str, "STRING", "SMALLINT")), } @@ -2051,7 +2046,6 @@ fn parse_string_to_i16_ansi(str: &str) -> SparkResult> { fn parse_string_to_i16_try(str: &str) -> SparkResult> { match do_parse_string_to_int_try::(str, i32::MIN)? { - None => Ok(None), Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)), _ => Ok(None), }