diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 5011917082..a2e12168dd 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -685,11 +685,18 @@ macro_rules! cast_decimal_to_int16_down {
             .map(|value| match value {
                 Some(value) => {
                     let divisor = 10_i128.pow($scale as u32);
-                    let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                    let truncated = value / divisor;
                     let is_overflow = truncated.abs() > i32::MAX.into();
                     if is_overflow {
                         return Err(cast_overflow(
-                            &format!("{}.{}BD", truncated, decimal),
+                            &format!(
+                                "{}BD",
+                                format_decimal_str(
+                                    &value.to_string(),
+                                    $precision as usize,
+                                    $scale
+                                )
+                            ),
                             &format!("DECIMAL({},{})", $precision, $scale),
                             $dest_type_str,
                         ));
@@ -698,7 +705,14 @@ macro_rules! cast_decimal_to_int16_down {
                     <$rust_dest_type>::try_from(i32_value)
                         .map_err(|_| {
                             cast_overflow(
-                                &format!("{}.{}BD", truncated, decimal),
+                                &format!(
+                                    "{}BD",
+                                    format_decimal_str(
+                                        &value.to_string(),
+                                        $precision as usize,
+                                        $scale
+                                    )
+                                ),
                                 &format!("DECIMAL({},{})", $precision, $scale),
                                 $dest_type_str,
                             )
@@ -748,11 +762,18 @@ macro_rules! cast_decimal_to_int32_up {
             .map(|value| match value {
                 Some(value) => {
                     let divisor = 10_i128.pow($scale as u32);
-                    let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                    let truncated = value / divisor;
                     let is_overflow = truncated.abs() > $max_dest_val.into();
                     if is_overflow {
                         return Err(cast_overflow(
-                            &format!("{}.{}BD", truncated, decimal),
+                            &format!(
+                                "{}BD",
+                                format_decimal_str(
+                                    &value.to_string(),
+                                    $precision as usize,
+                                    $scale
+                                )
+                            ),
                             &format!("DECIMAL({},{})", $precision, $scale),
                             $dest_type_str,
                         ));
@@ -780,6 +801,30 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+// copied from arrow::datatypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+    let (sign, rest) = match value_str.strip_prefix('-') {
+        Some(stripped) => ("-", stripped),
+        None => ("", value_str),
+    };
+    let bound = precision.min(rest.len()) + sign.len();
+    let value_str = &value_str[0..bound];
+
+    if scale == 0 {
+        value_str.to_string()
+    } else if scale < 0 {
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{value_str:0<padding$}")
+    } else if rest.len() > scale as usize {
+        // Decimal separator is in the middle of the string
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
+        format!("{whole}.{decimal}")
+    } else {
+        // String has to be padded
+        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
+    }
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -1866,12 +1911,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
         ),
         (DataType::Decimal128(precision, scale), DataType::Int8) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+                array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int16) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+                array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int32) => {
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1892749bec..8a68df3820 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast DecimalType(10,2) to ShortType") {
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
+    castTest(
+      generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
+      DataTypes.ShortType)
   }
 
   test("cast DecimalType(10,2) to IntegerType") {
@@ -553,14 +556,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast DecimalType(38,18) to ShortType") {
     castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
+    castTest(
+      generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+      DataTypes.ShortType)
   }
 
   test("cast DecimalType(38,18) to IntegerType") {
     castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
+    castTest(
+      generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+      DataTypes.IntegerType)
   }
 
   test("cast DecimalType(38,18) to LongType") {
     castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
+    castTest(
+      generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
+      DataTypes.LongType)
   }
 
   test("cast DecimalType(10,2) to StringType") {
@@ -1205,6 +1217,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       BigDecimal("32768.678"),
       BigDecimal("123456.789"),
       BigDecimal("99999999.999"))
+    generateDecimalsPrecision10Scale2(values)
+  }
+
+  private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
     withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
   }
 
@@ -1227,6 +1243,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       // Long Max
       BigDecimal("9223372036854775808.234567"),
       BigDecimal("99999999999999999999.999999999999"))
+    generateDecimalsPrecision38Scale18(values)
+  }
+
+  private def generateDecimalsPrecision38Scale18(values: Seq[BigDecimal]): DataFrame = {
     withNulls(values).toDF("a")
   }
 