Skip to content

Commit ca90587

Browse files
committed
perf_string_to_int
1 parent 768a017 commit ca90587

File tree

1 file changed

+52
-10
lines changed
  • native/spark-expr/src/conversion_funcs

1 file changed

+52
-10
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
/// Casts every element of a string array to an integer primitive array.
///
/// * `$array` - the input string array; must provide `len()`, `null_count()`,
///   `is_null(i)` and `value(i)`.
/// * `$eval_mode` - forwarded unchanged to `$cast_method` for every element.
/// * `$array_type` - the primitive type parameter of the output `PrimitiveArray`.
/// * `$cast_method` - per-element parser; its result is unwrapped with `?`, so an
///   error aborts the enclosing function, and `None` becomes a null output slot.
///
/// Evaluates to a `SparkResult<ArrayRef>` wrapping the finished array.
macro_rules! cast_utf8_to_int {
    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        let len = $array.len();
        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);

        if $array.null_count() == 0 {
            // The input has no nulls, so the per-element `is_null` check is skipped.
            for i in 0..len {
                match $cast_method($array.value(i), $eval_mode)? {
                    Some(cast_value) => {
                        cast_array.append_value(cast_value);
                    }
                    None => cast_array.append_null(),
                }
            }
        } else {
            for i in 0..len {
                if $array.is_null(i) {
                    // A null input slot stays null without invoking the parser.
                    cast_array.append_null();
                    continue;
                }
                match $cast_method($array.value(i), $eval_mode)? {
                    Some(cast_value) => {
                        cast_array.append_value(cast_value);
                    }
                    None => cast_array.append_null(),
                }
            }
        }

        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
        result
    }};
}
417+
405418
macro_rules! cast_utf8_to_timestamp {
406419
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
407420
let len = $array.len();
@@ -1931,6 +1944,35 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>
19311944

19321945
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
19331946
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
1947+
// happy path
1948+
let bytes = str.as_bytes();
1949+
let len = bytes.len();
1950+
if len > 0 && len <= 10 {
1951+
// SAFETY: We checked len > 0 above
1952+
let first = unsafe { *bytes.get_unchecked(0) };
1953+
// Must start with digit for happy path
1954+
if first >= b'0' && first <= b'9' {
1955+
let mut result: i64 = (first - b'0') as i64;
1956+
let mut i = 1;
1957+
1958+
// Try to parse remaining digits
1959+
while i < len {
1960+
let b = bytes[i];
1961+
if b >= b'0' && b <= b'9' {
1962+
result = result * 10 + (b - b'0') as i64;
1963+
i += 1;
1964+
} else {
1965+
// Hit non-digit (space, sign, decimal, etc.) - Bail to slow path
1966+
break;
1967+
}
1968+
}
1969+
if i == len && result <= i32::MAX as i64 {
1970+
return Ok(Some(result as i32));
1971+
}
1972+
// Otherwise fall through to slow path
1973+
}
1974+
}
1975+
19341976
do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
19351977
}
19361978

@@ -1965,7 +2007,6 @@ fn do_cast_string_to_int<
19652007
type_name: &str,
19662008
min_value: T,
19672009
) -> SparkResult<Option<T>> {
1968-
19692010
let bytes = str.as_bytes();
19702011
let mut start = 0;
19712012
let mut end = bytes.len();
@@ -1989,7 +2030,7 @@ fn do_cast_string_to_int<
19892030
let negative = first_char == b'-';
19902031
if negative || first_char == b'+' {
19912032
idx = 1;
1992-
if len == 1{
2033+
if len == 1 {
19932034
return none_or_err(eval_mode, type_name, str);
19942035
}
19952036
}
@@ -1998,7 +2039,7 @@ fn do_cast_string_to_int<
19982039
let stop_value = min_value / radix;
19992040
let mut parse_sign_and_digits = true;
20002041

2001-
for &ch in &trimmed_bytes[idx..] {
2042+
for &ch in &trimmed_bytes[idx..] {
20022043
if parse_sign_and_digits {
20032044
if ch == b'.' {
20042045
if eval_mode == EvalMode::Legacy {
@@ -2014,6 +2055,7 @@ fn do_cast_string_to_int<
20142055
return none_or_err(eval_mode, type_name, str);
20152056
}
20162057
let digit = T::from((ch - b'0') as i32);
2058+
result = (result << 3) + (result << 1) - digit;
20172059
result = result * radix - digit;
20182060

20192061
// We are going to process the new digit and accumulate the result. However, before

0 commit comments

Comments
 (0)