diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 314beb18ca..2ff1d8c551 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -386,12 +386,13 @@ fn can_cast_from_decimal(
 }
 
 macro_rules! cast_utf8_to_int {
-    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
+    ($array:expr, $array_type:ty, $parse_fn:expr) => {{
         let len = $array.len();
         let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
+        let parse_fn = $parse_fn;
         if $array.null_count() == 0 {
             for i in 0..len {
-                if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
+                if let Some(cast_value) = parse_fn($array.value(i))? {
                     cast_array.append_value(cast_value);
                 } else {
                     cast_array.append_null()
@@ -401,7 +402,7 @@ macro_rules! cast_utf8_to_int {
             for i in 0..len {
                 if $array.is_null(i) {
                     cast_array.append_null()
-                } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
+                } else if let Some(cast_value) = parse_fn($array.value(i))? {
                     cast_array.append_value(cast_value);
                 } else {
                     cast_array.append_null()
@@ -1473,22 +1474,70 @@ fn cast_string_to_int(
         .downcast_ref::<GenericStringArray<OffsetSize>>()
         .expect("cast_string_to_int expected a string array");
 
-    let cast_array: ArrayRef = match to_type {
-        DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
-        DataType::Int16 => {
-            cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
-        }
-        DataType::Int32 => {
-            cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
-        }
-        DataType::Int64 => {
-            cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
-        }
-        dt => unreachable!(
-            "{}",
-            format!("invalid integer type {dt} in cast from string")
-        ),
-    };
+    // Select parse function once per batch based on eval_mode
+    let cast_array: ArrayRef =
+        match (to_type, eval_mode) {
+            (DataType::Int8, EvalMode::Legacy) => {
+                cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_legacy)?
+            }
+            (DataType::Int8, EvalMode::Ansi) => {
+                cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_ansi)?
+            }
+            (DataType::Int8, EvalMode::Try) => {
+                cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_try)?
+            }
+            (DataType::Int16, EvalMode::Legacy) => {
+                cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_legacy)?
+            }
+            (DataType::Int16, EvalMode::Ansi) => {
+                cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_ansi)?
+            }
+            (DataType::Int16, EvalMode::Try) => {
+                cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_try)?
+            }
+            (DataType::Int32, EvalMode::Legacy) => cast_utf8_to_int!(
+                string_array,
+                Int32Type,
+                |s| do_parse_string_to_int_legacy::<i32>(s, i32::MIN)
+            )?,
+            (DataType::Int32, EvalMode::Ansi) => {
+                cast_utf8_to_int!(string_array, Int32Type, |s| do_parse_string_to_int_ansi::<
+                    i32,
+                >(
+                    s, "INT", i32::MIN
+                ))?
+            }
+            (DataType::Int32, EvalMode::Try) => {
+                cast_utf8_to_int!(
+                    string_array,
+                    Int32Type,
+                    |s| do_parse_string_to_int_try::<i32>(s, i32::MIN)
+                )?
+            }
+            (DataType::Int64, EvalMode::Legacy) => cast_utf8_to_int!(
+                string_array,
+                Int64Type,
+                |s| do_parse_string_to_int_legacy::<i64>(s, i64::MIN)
+            )?,
+            (DataType::Int64, EvalMode::Ansi) => {
+                cast_utf8_to_int!(string_array, Int64Type, |s| do_parse_string_to_int_ansi::<
+                    i64,
+                >(
+                    s, "BIGINT", i64::MIN
+                ))?
+            }
+            (DataType::Int64, EvalMode::Try) => {
+                cast_utf8_to_int!(
+                    string_array,
+                    Int64Type,
+                    |s| do_parse_string_to_int_try::<i64>(s, i64::MIN)
+                )?
+            }
+            (dt, _) => unreachable!(
+                "{}",
+                format!("invalid integer type {dt} in cast from string")
+            ),
+        };
 
     Ok(cast_array)
 }
@@ -1960,88 +2009,65 @@ fn spark_cast_nonintegral_numeric_to_integral(
     }
 }
 
-/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
-fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
-    Ok(cast_string_to_int_with_range_check(
-        str,
-        eval_mode,
-        "TINYINT",
-        i8::MIN as i32,
-        i8::MAX as i32,
-    )?
-    .map(|v| v as i8))
-}
-
-/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
-fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
-    Ok(cast_string_to_int_with_range_check(
-        str,
-        eval_mode,
-        "SMALLINT",
-        i16::MIN as i32,
-        i16::MAX as i32,
-    )?
-    .map(|v| v as i16))
+fn parse_string_to_i8_legacy(str: &str) -> SparkResult<Option<i8>> {
+    match do_parse_string_to_int_legacy::<i32>(str, i32::MIN)? {
+        Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)),
+        _ => Ok(None),
+    }
 }
 
-/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
-fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
-    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
+fn parse_string_to_i8_ansi(str: &str) -> SparkResult<Option<i8>> {
+    match do_parse_string_to_int_ansi::<i32>(str, "TINYINT", i32::MIN)? {
+        Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)),
+        _ => Err(invalid_value(str, "STRING", "TINYINT")),
+    }
 }
 
-/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper)
-fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
-    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
+fn parse_string_to_i8_try(str: &str) -> SparkResult<Option<i8>> {
+    match do_parse_string_to_int_try::<i32>(str, i32::MIN)? {
+        Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)),
+        _ => Ok(None),
+    }
 }
 
-fn cast_string_to_int_with_range_check(
-    str: &str,
-    eval_mode: EvalMode,
-    type_name: &str,
-    min: i32,
-    max: i32,
-) -> SparkResult<Option<i32>> {
-    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
-        None => Ok(None),
-        Some(v) if v >= min && v <= max => Ok(Some(v)),
-        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+fn parse_string_to_i16_legacy(str: &str) -> SparkResult<Option<i16>> {
+    match do_parse_string_to_int_legacy::<i32>(str, i32::MIN)? {
+        Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)),
         _ => Ok(None),
     }
 }
 
-// Returns (start, end) indices after trimming whitespace
-fn trim_whitespace(bytes: &[u8]) -> (usize, usize) {
-    let mut start = 0;
-    let mut end = bytes.len();
-
-    while start < end && bytes[start].is_ascii_whitespace() {
-        start += 1;
+fn parse_string_to_i16_ansi(str: &str) -> SparkResult<Option<i16>> {
+    match do_parse_string_to_int_ansi::<i32>(str, "SMALLINT", i32::MIN)? {
+        Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)),
+        _ => Err(invalid_value(str, "STRING", "SMALLINT")),
     }
-    while end > start && bytes[end - 1].is_ascii_whitespace() {
-        end -= 1;
-    }
-
-    (start, end)
 }
 
-// Parses sign and returns (is_negative, start_idx after sign)
-// Returns None if invalid (e.g., just "+" or "-")
-fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> {
-    let len = trimmed_bytes.len();
-    if len == 0 {
-        return None;
+fn parse_string_to_i16_try(str: &str) -> SparkResult<Option<i16>> {
+    match do_parse_string_to_int_try::<i32>(str, i32::MIN)? {
+        Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)),
+        _ => Ok(None),
     }
+}
 
-    let first_char = trimmed_bytes[0];
-    let negative = first_char == b'-';
+/// Parses sign and returns (is_negative, remaining_bytes after sign)
+/// Returns None if invalid (empty input, or just "+" or "-")
+fn parse_sign(bytes: &[u8]) -> Option<(bool, &[u8])> {
+    let (&first, rest) = bytes.split_first()?;
+    match first {
+        b'-' if !rest.is_empty() => Some((true, rest)),
+        b'+' if !rest.is_empty() => Some((false, rest)),
+        _ => Some((false, bytes)),
+    }
+}
 
-    if negative || first_char == b'+' {
-        if len == 1 {
-            return None;
-        }
-        Some((negative, 1))
+/// Finalizes the result by applying the sign. Returns None if overflow would occur.
+fn finalize_int_result<T: Num + CheckedNeg + PartialOrd>(result: T, negative: bool) -> Option<T> {
+    if negative {
+        Some(result)
     } else {
-        Some((false, 0))
+        result.checked_neg().filter(|&n| n >= T::zero())
     }
 }
 
@@ -2052,69 +2078,48 @@ fn do_parse_string_to_int_legacy
     str: &str,
     min_value: T,
 ) -> SparkResult<Option<T>> {
-    let bytes = str.as_bytes();
-    let (start, end) = trim_whitespace(bytes);
-
-    if start == end {
-        return Ok(None);
-    }
-    let trimmed_bytes = &bytes[start..end];
+    let trimmed_bytes = str.as_bytes().trim_ascii();
 
-    let (negative, idx) = match parse_sign(trimmed_bytes) {
+    let (negative, digits) = match parse_sign(trimmed_bytes) {
         Some(result) => result,
         None => return Ok(None),
     };
 
     let mut result: T = T::zero();
-
     let radix = T::from(10_u8);
     let stop_value = min_value / radix;
-    let mut parse_sign_and_digits = true;
-    for &ch in &trimmed_bytes[idx..] {
-        if parse_sign_and_digits {
-            if ch == b'.' {
-                // truncate decimal in legacy mode
-                parse_sign_and_digits = false;
-                continue;
-            }
+    let mut iter = digits.iter();
 
-            if !ch.is_ascii_digit() {
-                return Ok(None);
-            }
+    // Parse integer portion until '.' or end
+    for &ch in iter.by_ref() {
+        if ch == b'.' {
+            break;
+        }
 
-            let digit: T = T::from(ch - b'0');
+        if !ch.is_ascii_digit() {
+            return Ok(None);
+        }
 
-            if result < stop_value {
-                return Ok(None);
-            }
-            let v = result * radix;
-            match v.checked_sub(&digit) {
-                Some(x) if x <= T::zero() => result = x,
-                _ => {
-                    return Ok(None);
-                }
-            }
-        } else {
-            // in legacy mode we still process chars after the dot and make sure the chars are digits
-            if !ch.is_ascii_digit() {
-                return Ok(None);
-            }
+        if result < stop_value {
+            return Ok(None);
+        }
+        let v = result * radix;
+        let digit: T = T::from(ch - b'0');
+        match v.checked_sub(&digit) {
+            Some(x) if x <= T::zero() => result = x,
+            _ => return Ok(None),
         }
     }
 
-    if !negative {
-        if let Some(neg) = result.checked_neg() {
-            if neg < T::zero() {
-                return Ok(None);
-            }
-            result = neg;
-        } else {
+    // Validate decimal portion (digits only, values ignored)
+    for &ch in iter {
+        if !ch.is_ascii_digit() {
             return Ok(None);
         }
     }
 
-    Ok(Some(result))
+    Ok(finalize_int_result(result, negative))
 }
 
 fn do_parse_string_to_int_ansi<T: Num + PartialOrd + CheckedSub + CheckedNeg + From<u8> + Copy>(
     str: &str,
@@ -2122,132 +2127,72 @@ fn do_parse_string_to_int_ansi
     type_name: &str,
     min_value: T,
 ) -> SparkResult<Option<T>> {
-    let bytes = str.as_bytes();
-    let (start, end) = trim_whitespace(bytes);
+    let error = || Err(invalid_value(str, "STRING", type_name));
 
-    if start == end {
-        return Err(invalid_value(str, "STRING", type_name));
-    }
-    let trimmed_bytes = &bytes[start..end];
+    let trimmed_bytes = str.as_bytes().trim_ascii();
 
-    let (negative, idx) = match parse_sign(trimmed_bytes) {
+    let (negative, digits) = match parse_sign(trimmed_bytes) {
         Some(result) => result,
-        None => return Err(invalid_value(str, "STRING", type_name)),
+        None => return error(),
     };
 
     let mut result: T = T::zero();
-
     let radix = T::from(10_u8);
     let stop_value = min_value / radix;
 
-    for &ch in &trimmed_bytes[idx..] {
-        if ch == b'.' {
-            return Err(invalid_value(str, "STRING", type_name));
-        }
-
-        if !ch.is_ascii_digit() {
-            return Err(invalid_value(str, "STRING", type_name));
+    for &ch in digits {
+        if ch == b'.' || !ch.is_ascii_digit() {
+            return error();
         }
 
-        let digit: T = T::from(ch - b'0');
-
         if result < stop_value {
-            return Err(invalid_value(str, "STRING", type_name));
+            return error();
         }
         let v = result * radix;
+        let digit: T = T::from(ch - b'0');
         match v.checked_sub(&digit) {
             Some(x) if x <= T::zero() => result = x,
-            _ => {
-                return Err(invalid_value(str, "STRING", type_name));
-            }
+            _ => return error(),
         }
     }
 
-    if !negative {
-        if let Some(neg) = result.checked_neg() {
-            if neg < T::zero() {
-                return Err(invalid_value(str, "STRING", type_name));
-            }
-            result = neg;
-        } else {
-            return Err(invalid_value(str, "STRING", type_name));
-        }
-    }
-
-    Ok(Some(result))
+    finalize_int_result(result, negative)
+        .map(Some)
+        .ok_or_else(|| invalid_value(str, "STRING", type_name))
 }
 
 fn do_parse_string_to_int_try<T: Num + PartialOrd + CheckedSub + CheckedNeg + From<u8> + Copy>(
     str: &str,
     min_value: T,
 ) -> SparkResult<Option<T>> {
-    let bytes = str.as_bytes();
-    let (start, end) = trim_whitespace(bytes);
-
-    if start == end {
-        return Ok(None);
-    }
-    let trimmed_bytes = &bytes[start..end];
+    let trimmed_bytes = str.as_bytes().trim_ascii();
 
-    let (negative, idx) = match parse_sign(trimmed_bytes) {
+    let (negative, digits) = match parse_sign(trimmed_bytes) {
        Some(result) => result,
         None => return Ok(None),
     };
 
     let mut result: T = T::zero();
-
     let radix = T::from(10_u8);
     let stop_value = min_value / radix;
 
-    // we don't have to go beyond decimal point in try eval mode - early return NULL
-    for &ch in &trimmed_bytes[idx..] {
-        if ch == b'.' {
-            return Ok(None);
-        }
-
-        if !ch.is_ascii_digit() {
+    for &ch in digits {
+        if ch == b'.' || !ch.is_ascii_digit() {
            return Ok(None);
         }
 
-        let digit: T = T::from(ch - b'0');
-
         if result < stop_value {
             return Ok(None);
         }
         let v = result * radix;
+        let digit: T = T::from(ch - b'0');
         match v.checked_sub(&digit) {
             Some(x) if x <= T::zero() => result = x,
-            _ => {
-                return Ok(None);
-            }
-        }
-    }
-
-    if !negative {
-        if let Some(neg) = result.checked_neg() {
-            if neg < T::zero() {
-                return Ok(None);
-            }
-            result = neg;
-        } else {
-            return Ok(None);
+            _ => return Ok(None),
         }
     }
 
-    Ok(Some(result))
-}
-
-fn do_cast_string_to_int<T: Num + PartialOrd + CheckedSub + CheckedNeg + From<u8> + Copy>(
-    str: &str,
-    eval_mode: EvalMode,
-    type_name: &str,
-    min_value: T,
-) -> SparkResult<Option<T>> {
-    match eval_mode {
-        EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value),
-        EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value),
-        EvalMode::Try => do_parse_string_to_int_try(str, min_value),
-    }
+    Ok(finalize_int_result(result, negative))
 }
 
 fn cast_string_to_decimal(
@@ -3053,6 +2998,15 @@ mod tests {
 
     use super::*;
 
+    /// Test helper that wraps the mode-specific parse functions
+    fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
+        match eval_mode {
+            EvalMode::Legacy => parse_string_to_i8_legacy(str),
+            EvalMode::Ansi => parse_string_to_i8_ansi(str),
+            EvalMode::Try => parse_string_to_i8_try(str),
+        }
+    }
+
     #[test]
     #[cfg_attr(miri, ignore)] // test takes too long with miri
     fn timestamp_parser_test() {
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
index 18f292ff24..6b6cd8cfa6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometDatetimeExpressionBenchmark.scala
@@ -52,11 +52,20 @@ object CometDatetimeExpressionBenchmark extends CometBenchmarkBase {
         prepareTable(
           dir,
           spark.sql(s"select timestamp_micros(cast(value/100000 as integer)) as ts FROM $tbl"))
-        Seq("YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER").foreach {
-          level =>
-            val name = s"Timestamp Truncate - $level"
-            val query = s"select date_trunc('$level', ts) from parquetV1Table"
-            runExpressionBenchmark(name, values, query)
+        Seq(
+          "YEAR",
+          "QUARTER",
+          "MONTH",
+          "WEEK",
+          "DAY",
+          "HOUR",
+          "MINUTE",
+          "SECOND",
+          "MILLISECOND",
+          "MICROSECOND").foreach { level =>
+          val name = s"Timestamp Truncate - $level"
+          val query = s"select date_trunc('$level', ts) from parquetV1Table"
+          runExpressionBenchmark(name, values, query)
         }
       }
     }