Skip to content

Commit 069681a

Browse files
authored
perf: Improve performance of CAST from string to int (apache#3017)
1 parent 092d88c commit 069681a

File tree

1 file changed

+217
-69
lines changed
  • native/spark-expr/src/conversion_funcs

1 file changed

+217
-69
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 217 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ use datafusion::common::{
5454
use datafusion::physical_expr::PhysicalExpr;
5555
use datafusion::physical_plan::ColumnarValue;
5656
use num::{
57-
cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
58-
ToPrimitive, Zero,
57+
cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive,
58+
Zero,
5959
};
6060
use regex::Regex;
6161
use std::str::FromStr;
@@ -389,13 +389,23 @@ macro_rules! cast_utf8_to_int {
389389
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
390390
let len = $array.len();
391391
let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
392-
for i in 0..len {
393-
if $array.is_null(i) {
394-
cast_array.append_null()
395-
} else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
396-
cast_array.append_value(cast_value);
397-
} else {
398-
cast_array.append_null()
392+
if $array.null_count() == 0 {
393+
for i in 0..len {
394+
if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
395+
cast_array.append_value(cast_value);
396+
} else {
397+
cast_array.append_null()
398+
}
399+
}
400+
} else {
401+
for i in 0..len {
402+
if $array.is_null(i) {
403+
cast_array.append_null()
404+
} else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
405+
cast_array.append_value(cast_value);
406+
} else {
407+
cast_array.append_null()
408+
}
399409
}
400410
}
401411
let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
@@ -1999,100 +2009,247 @@ fn cast_string_to_int_with_range_check(
19992009
}
20002010
}
20012011

2012+
// Returns (start, end) indices after trimming whitespace
2013+
fn trim_whitespace(bytes: &[u8]) -> (usize, usize) {
2014+
let mut start = 0;
2015+
let mut end = bytes.len();
2016+
2017+
while start < end && bytes[start].is_ascii_whitespace() {
2018+
start += 1;
2019+
}
2020+
while end > start && bytes[end - 1].is_ascii_whitespace() {
2021+
end -= 1;
2022+
}
2023+
2024+
(start, end)
2025+
}
2026+
2027+
// Parses sign and returns (is_negative, start_idx after sign)
2028+
// Returns None if invalid (e.g., just "+" or "-")
2029+
fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> {
2030+
let len = trimmed_bytes.len();
2031+
if len == 0 {
2032+
return None;
2033+
}
2034+
2035+
let first_char = trimmed_bytes[0];
2036+
let negative = first_char == b'-';
2037+
2038+
if negative || first_char == b'+' {
2039+
if len == 1 {
2040+
return None;
2041+
}
2042+
Some((negative, 1))
2043+
} else {
2044+
Some((false, 0))
2045+
}
2046+
}
2047+
20022048
/// Equivalent to
20032049
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
20042050
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
2005-
fn do_cast_string_to_int<
2006-
T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
2007-
>(
2051+
fn do_parse_string_to_int_legacy<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
20082052
str: &str,
2009-
eval_mode: EvalMode,
2010-
type_name: &str,
20112053
min_value: T,
20122054
) -> SparkResult<Option<T>> {
2013-
let trimmed_str = str.trim();
2014-
if trimmed_str.is_empty() {
2015-
return none_or_err(eval_mode, type_name, str);
2055+
let bytes = str.as_bytes();
2056+
let (start, end) = trim_whitespace(bytes);
2057+
2058+
if start == end {
2059+
return Ok(None);
20162060
}
2017-
let len = trimmed_str.len();
2061+
let trimmed_bytes = &bytes[start..end];
2062+
2063+
let (negative, idx) = match parse_sign(trimmed_bytes) {
2064+
Some(result) => result,
2065+
None => return Ok(None),
2066+
};
2067+
20182068
let mut result: T = T::zero();
2019-
let mut negative = false;
2020-
let radix = T::from(10);
2069+
2070+
let radix = T::from(10_u8);
20212071
let stop_value = min_value / radix;
20222072
let mut parse_sign_and_digits = true;
20232073

2024-
for (i, ch) in trimmed_str.char_indices() {
2074+
for &ch in &trimmed_bytes[idx..] {
20252075
if parse_sign_and_digits {
2026-
if i == 0 {
2027-
negative = ch == '-';
2028-
let positive = ch == '+';
2029-
if negative || positive {
2030-
if i + 1 == len {
2031-
// input string is just "+" or "-"
2032-
return none_or_err(eval_mode, type_name, str);
2033-
}
2034-
// consume this char
2035-
continue;
2036-
}
2076+
if ch == b'.' {
2077+
// truncate decimal in legacy mode
2078+
parse_sign_and_digits = false;
2079+
continue;
20372080
}
20382081

2039-
if ch == '.' {
2040-
if eval_mode == EvalMode::Legacy {
2041-
// truncate decimal in legacy mode
2042-
parse_sign_and_digits = false;
2043-
continue;
2044-
} else {
2045-
return none_or_err(eval_mode, type_name, str);
2046-
}
2082+
if !ch.is_ascii_digit() {
2083+
return Ok(None);
20472084
}
20482085

2049-
let digit = if ch.is_ascii_digit() {
2050-
(ch as u32) - ('0' as u32)
2051-
} else {
2052-
return none_or_err(eval_mode, type_name, str);
2053-
};
2086+
let digit: T = T::from(ch - b'0');
20542087

2055-
// We are going to process the new digit and accumulate the result. However, before
2056-
// doing this, if the result is already smaller than the
2057-
// stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
2058-
// smaller than minValue, and we can stop
20592088
if result < stop_value {
2060-
return none_or_err(eval_mode, type_name, str);
2089+
return Ok(None);
20612090
}
2062-
2063-
// Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
2064-
// radix), we can just use `result > 0` to check overflow. If result
2065-
// overflows, we should stop
20662091
let v = result * radix;
2067-
let digit = (digit as i32).into();
20682092
match v.checked_sub(&digit) {
20692093
Some(x) if x <= T::zero() => result = x,
20702094
_ => {
2071-
return none_or_err(eval_mode, type_name, str);
2095+
return Ok(None);
20722096
}
20732097
}
20742098
} else {
2075-
// make sure fractional digits are valid digits but ignore them
2099+
// in legacy mode we still process chars after the dot and make sure the chars are digits
20762100
if !ch.is_ascii_digit() {
2077-
return none_or_err(eval_mode, type_name, str);
2101+
return Ok(None);
2102+
}
2103+
}
2104+
}
2105+
2106+
if !negative {
2107+
if let Some(neg) = result.checked_neg() {
2108+
if neg < T::zero() {
2109+
return Ok(None);
2110+
}
2111+
result = neg;
2112+
} else {
2113+
return Ok(None);
2114+
}
2115+
}
2116+
2117+
Ok(Some(result))
2118+
}
2119+
2120+
fn do_parse_string_to_int_ansi<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
2121+
str: &str,
2122+
type_name: &str,
2123+
min_value: T,
2124+
) -> SparkResult<Option<T>> {
2125+
let bytes = str.as_bytes();
2126+
let (start, end) = trim_whitespace(bytes);
2127+
2128+
if start == end {
2129+
return Err(invalid_value(str, "STRING", type_name));
2130+
}
2131+
let trimmed_bytes = &bytes[start..end];
2132+
2133+
let (negative, idx) = match parse_sign(trimmed_bytes) {
2134+
Some(result) => result,
2135+
None => return Err(invalid_value(str, "STRING", type_name)),
2136+
};
2137+
2138+
let mut result: T = T::zero();
2139+
2140+
let radix = T::from(10_u8);
2141+
let stop_value = min_value / radix;
2142+
2143+
for &ch in &trimmed_bytes[idx..] {
2144+
if ch == b'.' {
2145+
return Err(invalid_value(str, "STRING", type_name));
2146+
}
2147+
2148+
if !ch.is_ascii_digit() {
2149+
return Err(invalid_value(str, "STRING", type_name));
2150+
}
2151+
2152+
let digit: T = T::from(ch - b'0');
2153+
2154+
if result < stop_value {
2155+
return Err(invalid_value(str, "STRING", type_name));
2156+
}
2157+
let v = result * radix;
2158+
match v.checked_sub(&digit) {
2159+
Some(x) if x <= T::zero() => result = x,
2160+
_ => {
2161+
return Err(invalid_value(str, "STRING", type_name));
2162+
}
2163+
}
2164+
}
2165+
2166+
if !negative {
2167+
if let Some(neg) = result.checked_neg() {
2168+
if neg < T::zero() {
2169+
return Err(invalid_value(str, "STRING", type_name));
2170+
}
2171+
result = neg;
2172+
} else {
2173+
return Err(invalid_value(str, "STRING", type_name));
2174+
}
2175+
}
2176+
2177+
Ok(Some(result))
2178+
}
2179+
2180+
fn do_parse_string_to_int_try<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
2181+
str: &str,
2182+
min_value: T,
2183+
) -> SparkResult<Option<T>> {
2184+
let bytes = str.as_bytes();
2185+
let (start, end) = trim_whitespace(bytes);
2186+
2187+
if start == end {
2188+
return Ok(None);
2189+
}
2190+
let trimmed_bytes = &bytes[start..end];
2191+
2192+
let (negative, idx) = match parse_sign(trimmed_bytes) {
2193+
Some(result) => result,
2194+
None => return Ok(None),
2195+
};
2196+
2197+
let mut result: T = T::zero();
2198+
2199+
let radix = T::from(10_u8);
2200+
let stop_value = min_value / radix;
2201+
2202+
// we don't have to go beyond decimal point in try eval mode - early return NULL
2203+
for &ch in &trimmed_bytes[idx..] {
2204+
if ch == b'.' {
2205+
return Ok(None);
2206+
}
2207+
2208+
if !ch.is_ascii_digit() {
2209+
return Ok(None);
2210+
}
2211+
2212+
let digit: T = T::from(ch - b'0');
2213+
2214+
if result < stop_value {
2215+
return Ok(None);
2216+
}
2217+
let v = result * radix;
2218+
match v.checked_sub(&digit) {
2219+
Some(x) if x <= T::zero() => result = x,
2220+
_ => {
2221+
return Ok(None);
20782222
}
20792223
}
20802224
}
20812225

20822226
if !negative {
20832227
if let Some(neg) = result.checked_neg() {
20842228
if neg < T::zero() {
2085-
return none_or_err(eval_mode, type_name, str);
2229+
return Ok(None);
20862230
}
20872231
result = neg;
20882232
} else {
2089-
return none_or_err(eval_mode, type_name, str);
2233+
return Ok(None);
20902234
}
20912235
}
20922236

20932237
Ok(Some(result))
20942238
}
20952239

2240+
fn do_cast_string_to_int<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
2241+
str: &str,
2242+
eval_mode: EvalMode,
2243+
type_name: &str,
2244+
min_value: T,
2245+
) -> SparkResult<Option<T>> {
2246+
match eval_mode {
2247+
EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value),
2248+
EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value),
2249+
EvalMode::Try => do_parse_string_to_int_try(str, min_value),
2250+
}
2251+
}
2252+
20962253
fn cast_string_to_decimal(
20972254
array: &ArrayRef,
20982255
to_type: &DataType,
@@ -2393,15 +2550,6 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
23932550
Ok((final_mantissa, final_scale))
23942551
}
23952552

2396-
/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
2397-
#[inline]
2398-
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
2399-
match eval_mode {
2400-
EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
2401-
_ => Ok(None),
2402-
}
2403-
}
2404-
24052553
#[inline]
24062554
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
24072555
SparkError::CastInvalidValue {

0 commit comments

Comments
 (0)