@@ -31,8 +31,8 @@ use arrow::{
3131 } ,
3232 compute:: { cast_with_options, take, unary, CastOptions } ,
3333 datatypes:: {
34- ArrowPrimitiveType , Decimal128Type , DecimalType , Float32Type , Float64Type , Int64Type ,
35- TimestampMicrosecondType ,
34+ is_validate_decimal_precision , ArrowPrimitiveType , Decimal128Type , Float32Type ,
35+ Float64Type , Int64Type , TimestampMicrosecondType ,
3636 } ,
3737 error:: ArrowError ,
3838 record_batch:: RecordBatch ,
@@ -1287,38 +1287,25 @@ where
12871287 for i in 0 ..input. len ( ) {
12881288 if input. is_null ( i) {
12891289 cast_array. append_null ( ) ;
1290- } else {
1291- let input_value = input. value ( i) . as_ ( ) ;
1292- let value = ( input_value * mul) . round ( ) . to_i128 ( ) ;
1293-
1294- match value {
1295- Some ( v) => {
1296- if Decimal128Type :: validate_decimal_precision ( v, precision) . is_err ( ) {
1297- if eval_mode == EvalMode :: Ansi {
1298- return Err ( SparkError :: NumericValueOutOfRange {
1299- value : input_value. to_string ( ) ,
1300- precision,
1301- scale,
1302- } ) ;
1303- } else {
1304- cast_array. append_null ( ) ;
1305- }
1306- }
1307- cast_array. append_value ( v) ;
1308- }
1309- None => {
1310- if eval_mode == EvalMode :: Ansi {
1311- return Err ( SparkError :: NumericValueOutOfRange {
1312- value : input_value. to_string ( ) ,
1313- precision,
1314- scale,
1315- } ) ;
1316- } else {
1317- cast_array. append_null ( ) ;
1318- }
1319- }
1290+ continue ;
1291+ }
1292+
1293+ let input_value = input. value ( i) . as_ ( ) ;
1294+ if let Some ( v) = ( input_value * mul) . round ( ) . to_i128 ( ) {
1295+ if is_validate_decimal_precision ( v, precision) {
1296+ cast_array. append_value ( v) ;
1297+ continue ;
13201298 }
1299+ } ;
1300+
1301+ if eval_mode == EvalMode :: Ansi {
1302+ return Err ( SparkError :: NumericValueOutOfRange {
1303+ value : input_value. to_string ( ) ,
1304+ precision,
1305+ scale,
1306+ } ) ;
13211307 }
1308+ cast_array. append_null ( ) ;
13221309 }
13231310
13241311 let res = Arc :: new (
@@ -2203,6 +2190,7 @@ mod tests {
22032190 use arrow:: array:: StringArray ;
22042191 use arrow:: datatypes:: TimestampMicrosecondType ;
22052192 use arrow:: datatypes:: { Field , Fields , TimeUnit } ;
2193+ use core:: f64;
22062194 use std:: str:: FromStr ;
22072195
22082196 use super :: * ;
@@ -2671,4 +2659,35 @@ mod tests {
26712659 unreachable ! ( )
26722660 }
26732661 }
2662+
2663+ #[ test]
2664+ fn test_cast_float_to_decimal ( ) {
2665+ let a: ArrayRef = Arc :: new ( Float64Array :: from ( vec ! [
2666+ Some ( 42. ) ,
2667+ Some ( 0.5153125 ) ,
2668+ Some ( -42.4242415 ) ,
2669+ Some ( 42e-314 ) ,
2670+ Some ( 0. ) ,
2671+ Some ( -4242.424242 ) ,
2672+ Some ( f64 :: INFINITY ) ,
2673+ Some ( f64 :: NEG_INFINITY ) ,
2674+ Some ( f64 :: NAN ) ,
2675+ None ,
2676+ ] ) ) ;
2677+ let b =
2678+ cast_floating_point_to_decimal128 :: < Float64Type > ( & a, 8 , 6 , EvalMode :: Legacy ) . unwrap ( ) ;
2679+ assert_eq ! ( b. len( ) , a. len( ) ) ;
2680+ let casted = b. as_primitive :: < Decimal128Type > ( ) ;
2681+ assert_eq ! ( casted. value( 0 ) , 42000000 ) ;
2682+ // https://github.com/apache/datafusion-comet/issues/1371
2683+ // assert_eq!(casted.value(1), 515313);
2684+ assert_eq ! ( casted. value( 2 ) , -42424242 ) ;
2685+ assert_eq ! ( casted. value( 3 ) , 0 ) ;
2686+ assert_eq ! ( casted. value( 4 ) , 0 ) ;
2687+ assert ! ( casted. is_null( 5 ) ) ;
2688+ assert ! ( casted. is_null( 6 ) ) ;
2689+ assert ! ( casted. is_null( 7 ) ) ;
2690+ assert ! ( casted. is_null( 8 ) ) ;
2691+ assert ! ( casted. is_null( 9 ) ) ;
2692+ }
26742693}
0 commit comments