Skip to content

Commit 726c4b7

Browse files
authored
fix: format decimal to string when casting decimal with overflow (#2916)
1 parent 37cb5c9 commit 726c4b7

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -685,11 +685,18 @@ macro_rules! cast_decimal_to_int16_down {
685685
.map(|value| match value {
686686
Some(value) => {
687687
let divisor = 10_i128.pow($scale as u32);
688-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
688+
let truncated = value / divisor;
689689
let is_overflow = truncated.abs() > i32::MAX.into();
690690
if is_overflow {
691691
return Err(cast_overflow(
692-
&format!("{}.{}BD", truncated, decimal),
692+
&format!(
693+
"{}BD",
694+
format_decimal_str(
695+
&value.to_string(),
696+
$precision as usize,
697+
$scale
698+
)
699+
),
693700
&format!("DECIMAL({},{})", $precision, $scale),
694701
$dest_type_str,
695702
));
@@ -698,7 +705,14 @@ macro_rules! cast_decimal_to_int16_down {
698705
<$rust_dest_type>::try_from(i32_value)
699706
.map_err(|_| {
700707
cast_overflow(
701-
&format!("{}.{}BD", truncated, decimal),
708+
&format!(
709+
"{}BD",
710+
format_decimal_str(
711+
&value.to_string(),
712+
$precision as usize,
713+
$scale
714+
)
715+
),
702716
&format!("DECIMAL({},{})", $precision, $scale),
703717
$dest_type_str,
704718
)
@@ -748,11 +762,18 @@ macro_rules! cast_decimal_to_int32_up {
748762
.map(|value| match value {
749763
Some(value) => {
750764
let divisor = 10_i128.pow($scale as u32);
751-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
765+
let truncated = value / divisor;
752766
let is_overflow = truncated.abs() > $max_dest_val.into();
753767
if is_overflow {
754768
return Err(cast_overflow(
755-
&format!("{}.{}BD", truncated, decimal),
769+
&format!(
770+
"{}BD",
771+
format_decimal_str(
772+
&value.to_string(),
773+
$precision as usize,
774+
$scale
775+
)
776+
),
756777
&format!("DECIMAL({},{})", $precision, $scale),
757778
$dest_type_str,
758779
));
@@ -780,6 +801,30 @@ macro_rules! cast_decimal_to_int32_up {
780801
}};
781802
}
782803

804+
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
805+
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
806+
let (sign, rest) = match value_str.strip_prefix('-') {
807+
Some(stripped) => ("-", stripped),
808+
None => ("", value_str),
809+
};
810+
let bound = precision.min(rest.len()) + sign.len();
811+
let value_str = &value_str[0..bound];
812+
813+
if scale == 0 {
814+
value_str.to_string()
815+
} else if scale < 0 {
816+
let padding = value_str.len() + scale.unsigned_abs() as usize;
817+
format!("{value_str:0<padding$}")
818+
} else if rest.len() > scale as usize {
819+
// Decimal separator is in the middle of the string
820+
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
821+
format!("{whole}.{decimal}")
822+
} else {
823+
// String has to be padded
824+
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
825+
}
826+
}
827+
783828
impl Cast {
784829
pub fn new(
785830
child: Arc<dyn PhysicalExpr>,
@@ -1866,12 +1911,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
18661911
),
18671912
(DataType::Decimal128(precision, scale), DataType::Int8) => {
18681913
cast_decimal_to_int16_down!(
1869-
array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1914+
array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
18701915
)
18711916
}
18721917
(DataType::Decimal128(precision, scale), DataType::Int16) => {
18731918
cast_decimal_to_int16_down!(
1874-
array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1919+
array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
18751920
)
18761921
}
18771922
(DataType::Decimal128(precision, scale), DataType::Int32) => {

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
529529

530530
test("cast DecimalType(10,2) to ShortType") {
531531
castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
532+
castTest(
533+
generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
534+
DataTypes.ShortType)
532535
}
533536

534537
test("cast DecimalType(10,2) to IntegerType") {
@@ -553,14 +556,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
553556

554557
test("cast DecimalType(38,18) to ShortType") {
555558
castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
559+
castTest(
560+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
561+
DataTypes.ShortType)
556562
}
557563

558564
test("cast DecimalType(38,18) to IntegerType") {
559565
castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
566+
castTest(
567+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
568+
DataTypes.IntegerType)
560569
}
561570

562571
test("cast DecimalType(38,18) to LongType") {
563572
castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
573+
castTest(
574+
generateDecimalsPrecision38Scale18(Seq(BigDecimal("-99999999999999999999.07"))),
575+
DataTypes.LongType)
564576
}
565577

566578
test("cast DecimalType(10,2) to StringType") {
@@ -1205,6 +1217,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12051217
BigDecimal("32768.678"),
12061218
BigDecimal("123456.789"),
12071219
BigDecimal("99999999.999"))
1220+
generateDecimalsPrecision10Scale2(values)
1221+
}
1222+
1223+
private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
12081224
withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
12091225
}
12101226

@@ -1227,6 +1243,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12271243
// Long Max
12281244
BigDecimal("9223372036854775808.234567"),
12291245
BigDecimal("99999999999999999999.999999999999"))
1246+
generateDecimalsPrecision38Scale18(values)
1247+
}
1248+
1249+
private def generateDecimalsPrecision38Scale18(values: Seq[BigDecimal]): DataFrame = {
12301250
withNulls(values).toDF("a")
12311251
}
12321252

0 commit comments

Comments
 (0)