Skip to content

Commit 9d1f22d

Browse files
committed
fix: format decimal to string when casting to short
1 parent 0556f5e commit 9d1f22d

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -692,11 +692,13 @@ macro_rules! cast_decimal_to_int16_down {
692692
.map(|value| match value {
693693
Some(value) => {
694694
let divisor = 10_i128.pow($scale as u32);
695-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
695+
let truncated = value / divisor;
696696
let is_overflow = truncated.abs() > i32::MAX.into();
697+
let fmt_str =
698+
format_decimal_str(&value.to_string(), $precision as usize, $scale);
697699
if is_overflow {
698700
return Err(cast_overflow(
699-
&format!("{}.{}BD", truncated, decimal),
701+
&format!("{}BD", fmt_str),
700702
&format!("DECIMAL({},{})", $precision, $scale),
701703
$dest_type_str,
702704
));
@@ -705,7 +707,7 @@ macro_rules! cast_decimal_to_int16_down {
705707
<$rust_dest_type>::try_from(i32_value)
706708
.map_err(|_| {
707709
cast_overflow(
708-
&format!("{}.{}BD", truncated, decimal),
710+
&format!("{}BD", fmt_str),
709711
&format!("DECIMAL({},{})", $precision, $scale),
710712
$dest_type_str,
711713
)
@@ -787,6 +789,30 @@ macro_rules! cast_decimal_to_int32_up {
787789
}};
788790
}
789791

792+
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
793+
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
794+
let (sign, rest) = match value_str.strip_prefix('-') {
795+
Some(stripped) => ("-", stripped),
796+
None => ("", value_str),
797+
};
798+
let bound = precision.min(rest.len()) + sign.len();
799+
let value_str = &value_str[0..bound];
800+
801+
if scale == 0 {
802+
value_str.to_string()
803+
} else if scale < 0 {
804+
let padding = value_str.len() + scale.unsigned_abs() as usize;
805+
format!("{value_str:0<padding$}")
806+
} else if rest.len() > scale as usize {
807+
// Decimal separator is in the middle of the string
808+
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
809+
format!("{whole}.{decimal}")
810+
} else {
811+
// String has to be padded
812+
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
813+
}
814+
}
815+
790816
impl Cast {
791817
pub fn new(
792818
child: Arc<dyn PhysicalExpr>,
@@ -1794,12 +1820,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
17941820
),
17951821
(DataType::Decimal128(precision, scale), DataType::Int8) => {
17961822
cast_decimal_to_int16_down!(
1797-
array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1823+
array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
17981824
)
17991825
}
18001826
(DataType::Decimal128(precision, scale), DataType::Int16) => {
18011827
cast_decimal_to_int16_down!(
1802-
array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1828+
array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
18031829
)
18041830
}
18051831
(DataType::Decimal128(precision, scale), DataType::Int32) => {

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
529529

530530
test("cast DecimalType(10,2) to ShortType") {
531531
castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
532+
castTest(
533+
generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
534+
DataTypes.ShortType)
532535
}
533536

534537
test("cast DecimalType(10,2) to IntegerType") {
@@ -1135,6 +1138,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11351138
BigDecimal("32768.678"),
11361139
BigDecimal("123456.789"),
11371140
BigDecimal("99999999.999"))
1141+
generateDecimalsPrecision10Scale2(values)
1142+
}
1143+
1144+
private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
11381145
withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
11391146
}
11401147

0 commit comments

Comments
 (0)