@@ -54,8 +54,8 @@ use datafusion::common::{
5454use datafusion:: physical_expr:: PhysicalExpr ;
5555use datafusion:: physical_plan:: ColumnarValue ;
5656use num:: {
57- cast:: AsPrimitive , integer:: div_floor, traits:: CheckedNeg , CheckedSub , Integer , Num ,
58- ToPrimitive , Zero ,
57+ cast:: AsPrimitive , integer:: div_floor, traits:: CheckedNeg , CheckedSub , Integer , ToPrimitive ,
58+ Zero ,
5959} ;
6060use regex:: Regex ;
6161use std:: str:: FromStr ;
@@ -389,13 +389,23 @@ macro_rules! cast_utf8_to_int {
389389 ( $array: expr, $eval_mode: expr, $array_type: ty, $cast_method: ident) => { {
390390 let len = $array. len( ) ;
391391 let mut cast_array = PrimitiveArray :: <$array_type>:: builder( len) ;
392- for i in 0 ..len {
393- if $array. is_null( i) {
394- cast_array. append_null( )
395- } else if let Some ( cast_value) = $cast_method( $array. value( i) , $eval_mode) ? {
396- cast_array. append_value( cast_value) ;
397- } else {
398- cast_array. append_null( )
392+ if $array. null_count( ) == 0 {
393+ for i in 0 ..len {
394+ if let Some ( cast_value) = $cast_method( $array. value( i) , $eval_mode) ? {
395+ cast_array. append_value( cast_value) ;
396+ } else {
397+ cast_array. append_null( )
398+ }
399+ }
400+ } else {
401+ for i in 0 ..len {
402+ if $array. is_null( i) {
403+ cast_array. append_null( )
404+ } else if let Some ( cast_value) = $cast_method( $array. value( i) , $eval_mode) ? {
405+ cast_array. append_value( cast_value) ;
406+ } else {
407+ cast_array. append_null( )
408+ }
399409 }
400410 }
401411 let result: SparkResult <ArrayRef > = Ok ( Arc :: new( cast_array. finish( ) ) as ArrayRef ) ;
@@ -1999,100 +2009,247 @@ fn cast_string_to_int_with_range_check(
19992009 }
20002010}
20012011
2012+ // Returns (start, end) indices after trimming whitespace
2013+ fn trim_whitespace ( bytes : & [ u8 ] ) -> ( usize , usize ) {
2014+ let mut start = 0 ;
2015+ let mut end = bytes. len ( ) ;
2016+
2017+ while start < end && bytes[ start] . is_ascii_whitespace ( ) {
2018+ start += 1 ;
2019+ }
2020+ while end > start && bytes[ end - 1 ] . is_ascii_whitespace ( ) {
2021+ end -= 1 ;
2022+ }
2023+
2024+ ( start, end)
2025+ }
2026+
2027+ // Parses sign and returns (is_negative, start_idx after sign)
2028+ // Returns None if invalid (e.g., just "+" or "-")
2029+ fn parse_sign ( trimmed_bytes : & [ u8 ] ) -> Option < ( bool , usize ) > {
2030+ let len = trimmed_bytes. len ( ) ;
2031+ if len == 0 {
2032+ return None ;
2033+ }
2034+
2035+ let first_char = trimmed_bytes[ 0 ] ;
2036+ let negative = first_char == b'-' ;
2037+
2038+ if negative || first_char == b'+' {
2039+ if len == 1 {
2040+ return None ;
2041+ }
2042+ Some ( ( negative, 1 ) )
2043+ } else {
2044+ Some ( ( false , 0 ) )
2045+ }
2046+ }
2047+
20022048/// Equivalent to
20032049/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
20042050/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
2005- fn do_cast_string_to_int <
2006- T : Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From < i32 > + Copy ,
2007- > (
2051+ fn do_parse_string_to_int_legacy < T : Integer + CheckedSub + CheckedNeg + From < u8 > + Copy > (
20082052 str : & str ,
2009- eval_mode : EvalMode ,
2010- type_name : & str ,
20112053 min_value : T ,
20122054) -> SparkResult < Option < T > > {
2013- let trimmed_str = str. trim ( ) ;
2014- if trimmed_str. is_empty ( ) {
2015- return none_or_err ( eval_mode, type_name, str) ;
2055+ let bytes = str. as_bytes ( ) ;
2056+ let ( start, end) = trim_whitespace ( bytes) ;
2057+
2058+ if start == end {
2059+ return Ok ( None ) ;
20162060 }
2017- let len = trimmed_str. len ( ) ;
2061+ let trimmed_bytes = & bytes[ start..end] ;
2062+
2063+ let ( negative, idx) = match parse_sign ( trimmed_bytes) {
2064+ Some ( result) => result,
2065+ None => return Ok ( None ) ,
2066+ } ;
2067+
20182068 let mut result: T = T :: zero ( ) ;
2019- let mut negative = false ;
2020- let radix = T :: from ( 10 ) ;
2069+
2070+ let radix = T :: from ( 10_u8 ) ;
20212071 let stop_value = min_value / radix;
20222072 let mut parse_sign_and_digits = true ;
20232073
2024- for ( i , ch ) in trimmed_str . char_indices ( ) {
2074+ for & ch in & trimmed_bytes [ idx.. ] {
20252075 if parse_sign_and_digits {
2026- if i == 0 {
2027- negative = ch == '-' ;
2028- let positive = ch == '+' ;
2029- if negative || positive {
2030- if i + 1 == len {
2031- // input string is just "+" or "-"
2032- return none_or_err ( eval_mode, type_name, str) ;
2033- }
2034- // consume this char
2035- continue ;
2036- }
2076+ if ch == b'.' {
2077+ // truncate decimal in legacy mode
2078+ parse_sign_and_digits = false ;
2079+ continue ;
20372080 }
20382081
2039- if ch == '.' {
2040- if eval_mode == EvalMode :: Legacy {
2041- // truncate decimal in legacy mode
2042- parse_sign_and_digits = false ;
2043- continue ;
2044- } else {
2045- return none_or_err ( eval_mode, type_name, str) ;
2046- }
2082+ if !ch. is_ascii_digit ( ) {
2083+ return Ok ( None ) ;
20472084 }
20482085
2049- let digit = if ch. is_ascii_digit ( ) {
2050- ( ch as u32 ) - ( '0' as u32 )
2051- } else {
2052- return none_or_err ( eval_mode, type_name, str) ;
2053- } ;
2086+ let digit: T = T :: from ( ch - b'0' ) ;
20542087
2055- // We are going to process the new digit and accumulate the result. However, before
2056- // doing this, if the result is already smaller than the
2057- // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
2058- // smaller than minValue, and we can stop
20592088 if result < stop_value {
2060- return none_or_err ( eval_mode , type_name , str ) ;
2089+ return Ok ( None ) ;
20612090 }
2062-
2063- // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
2064- // radix), we can just use `result > 0` to check overflow. If result
2065- // overflows, we should stop
20662091 let v = result * radix;
2067- let digit = ( digit as i32 ) . into ( ) ;
20682092 match v. checked_sub ( & digit) {
20692093 Some ( x) if x <= T :: zero ( ) => result = x,
20702094 _ => {
2071- return none_or_err ( eval_mode , type_name , str ) ;
2095+ return Ok ( None ) ;
20722096 }
20732097 }
20742098 } else {
2075- // make sure fractional digits are valid digits but ignore them
2099+ // in legacy mode we still process chars after the dot and make sure the chars are digits
20762100 if !ch. is_ascii_digit ( ) {
2077- return none_or_err ( eval_mode, type_name, str) ;
2101+ return Ok ( None ) ;
2102+ }
2103+ }
2104+ }
2105+
2106+ if !negative {
2107+ if let Some ( neg) = result. checked_neg ( ) {
2108+ if neg < T :: zero ( ) {
2109+ return Ok ( None ) ;
2110+ }
2111+ result = neg;
2112+ } else {
2113+ return Ok ( None ) ;
2114+ }
2115+ }
2116+
2117+ Ok ( Some ( result) )
2118+ }
2119+
2120+ fn do_parse_string_to_int_ansi < T : Integer + CheckedSub + CheckedNeg + From < u8 > + Copy > (
2121+ str : & str ,
2122+ type_name : & str ,
2123+ min_value : T ,
2124+ ) -> SparkResult < Option < T > > {
2125+ let bytes = str. as_bytes ( ) ;
2126+ let ( start, end) = trim_whitespace ( bytes) ;
2127+
2128+ if start == end {
2129+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2130+ }
2131+ let trimmed_bytes = & bytes[ start..end] ;
2132+
2133+ let ( negative, idx) = match parse_sign ( trimmed_bytes) {
2134+ Some ( result) => result,
2135+ None => return Err ( invalid_value ( str, "STRING" , type_name) ) ,
2136+ } ;
2137+
2138+ let mut result: T = T :: zero ( ) ;
2139+
2140+ let radix = T :: from ( 10_u8 ) ;
2141+ let stop_value = min_value / radix;
2142+
2143+ for & ch in & trimmed_bytes[ idx..] {
2144+ if ch == b'.' {
2145+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2146+ }
2147+
2148+ if !ch. is_ascii_digit ( ) {
2149+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2150+ }
2151+
2152+ let digit: T = T :: from ( ch - b'0' ) ;
2153+
2154+ if result < stop_value {
2155+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2156+ }
2157+ let v = result * radix;
2158+ match v. checked_sub ( & digit) {
2159+ Some ( x) if x <= T :: zero ( ) => result = x,
2160+ _ => {
2161+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2162+ }
2163+ }
2164+ }
2165+
2166+ if !negative {
2167+ if let Some ( neg) = result. checked_neg ( ) {
2168+ if neg < T :: zero ( ) {
2169+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2170+ }
2171+ result = neg;
2172+ } else {
2173+ return Err ( invalid_value ( str, "STRING" , type_name) ) ;
2174+ }
2175+ }
2176+
2177+ Ok ( Some ( result) )
2178+ }
2179+
2180+ fn do_parse_string_to_int_try < T : Integer + CheckedSub + CheckedNeg + From < u8 > + Copy > (
2181+ str : & str ,
2182+ min_value : T ,
2183+ ) -> SparkResult < Option < T > > {
2184+ let bytes = str. as_bytes ( ) ;
2185+ let ( start, end) = trim_whitespace ( bytes) ;
2186+
2187+ if start == end {
2188+ return Ok ( None ) ;
2189+ }
2190+ let trimmed_bytes = & bytes[ start..end] ;
2191+
2192+ let ( negative, idx) = match parse_sign ( trimmed_bytes) {
2193+ Some ( result) => result,
2194+ None => return Ok ( None ) ,
2195+ } ;
2196+
2197+ let mut result: T = T :: zero ( ) ;
2198+
2199+ let radix = T :: from ( 10_u8 ) ;
2200+ let stop_value = min_value / radix;
2201+
2202+ // we don't have to go beyond decimal point in try eval mode - early return NULL
2203+ for & ch in & trimmed_bytes[ idx..] {
2204+ if ch == b'.' {
2205+ return Ok ( None ) ;
2206+ }
2207+
2208+ if !ch. is_ascii_digit ( ) {
2209+ return Ok ( None ) ;
2210+ }
2211+
2212+ let digit: T = T :: from ( ch - b'0' ) ;
2213+
2214+ if result < stop_value {
2215+ return Ok ( None ) ;
2216+ }
2217+ let v = result * radix;
2218+ match v. checked_sub ( & digit) {
2219+ Some ( x) if x <= T :: zero ( ) => result = x,
2220+ _ => {
2221+ return Ok ( None ) ;
20782222 }
20792223 }
20802224 }
20812225
20822226 if !negative {
20832227 if let Some ( neg) = result. checked_neg ( ) {
20842228 if neg < T :: zero ( ) {
2085- return none_or_err ( eval_mode , type_name , str ) ;
2229+ return Ok ( None ) ;
20862230 }
20872231 result = neg;
20882232 } else {
2089- return none_or_err ( eval_mode , type_name , str ) ;
2233+ return Ok ( None ) ;
20902234 }
20912235 }
20922236
20932237 Ok ( Some ( result) )
20942238}
20952239
2240+ fn do_cast_string_to_int < T : Integer + CheckedSub + CheckedNeg + From < u8 > + Copy > (
2241+ str : & str ,
2242+ eval_mode : EvalMode ,
2243+ type_name : & str ,
2244+ min_value : T ,
2245+ ) -> SparkResult < Option < T > > {
2246+ match eval_mode {
2247+ EvalMode :: Legacy => do_parse_string_to_int_legacy ( str, min_value) ,
2248+ EvalMode :: Ansi => do_parse_string_to_int_ansi ( str, type_name, min_value) ,
2249+ EvalMode :: Try => do_parse_string_to_int_try ( str, min_value) ,
2250+ }
2251+ }
2252+
20962253fn cast_string_to_decimal (
20972254 array : & ArrayRef ,
20982255 to_type : & DataType ,
@@ -2393,15 +2550,6 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
23932550 Ok ( ( final_mantissa, final_scale) )
23942551}
23952552
2396- /// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
2397- #[ inline]
2398- fn none_or_err < T > ( eval_mode : EvalMode , type_name : & str , str : & str ) -> SparkResult < Option < T > > {
2399- match eval_mode {
2400- EvalMode :: Ansi => Err ( invalid_value ( str, "STRING" , type_name) ) ,
2401- _ => Ok ( None ) ,
2402- }
2403- }
2404-
24052553#[ inline]
24062554fn invalid_value ( value : & str , from_type : & str , to_type : & str ) -> SparkError {
24072555 SparkError :: CastInvalidValue {
0 commit comments