Skip to content

Commit 26c073e

Browse files
committed
fix: format decimal to string when casting to short
1 parent a9d0c2b commit 26c073e

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -691,11 +691,18 @@ macro_rules! cast_decimal_to_int16_down {
691691
.map(|value| match value {
692692
Some(value) => {
693693
let divisor = 10_i128.pow($scale as u32);
694-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
694+
let truncated = value / divisor;
695695
let is_overflow = truncated.abs() > i32::MAX.into();
696696
if is_overflow {
697697
return Err(cast_overflow(
698-
&format!("{}.{}BD", truncated, decimal),
698+
&format!(
699+
"{}BD",
700+
format_decimal_str(
701+
&value.to_string(),
702+
$precision as usize,
703+
$scale
704+
)
705+
),
699706
&format!("DECIMAL({},{})", $precision, $scale),
700707
$dest_type_str,
701708
));
@@ -704,7 +711,14 @@ macro_rules! cast_decimal_to_int16_down {
704711
<$rust_dest_type>::try_from(i32_value)
705712
.map_err(|_| {
706713
cast_overflow(
707-
&format!("{}.{}BD", truncated, decimal),
714+
&format!(
715+
"{}BD",
716+
format_decimal_str(
717+
&value.to_string(),
718+
$precision as usize,
719+
$scale
720+
)
721+
),
708722
&format!("DECIMAL({},{})", $precision, $scale),
709723
$dest_type_str,
710724
)
@@ -786,6 +800,30 @@ macro_rules! cast_decimal_to_int32_up {
786800
}};
787801
}
788802

803+
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
804+
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
805+
let (sign, rest) = match value_str.strip_prefix('-') {
806+
Some(stripped) => ("-", stripped),
807+
None => ("", value_str),
808+
};
809+
let bound = precision.min(rest.len()) + sign.len();
810+
let value_str = &value_str[0..bound];
811+
812+
if scale == 0 {
813+
value_str.to_string()
814+
} else if scale < 0 {
815+
let padding = value_str.len() + scale.unsigned_abs() as usize;
816+
format!("{value_str:0<padding$}")
817+
} else if rest.len() > scale as usize {
818+
// Decimal separator is in the middle of the string
819+
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
820+
format!("{whole}.{decimal}")
821+
} else {
822+
// String has to be padded
823+
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
824+
}
825+
}
826+
789827
impl Cast {
790828
pub fn new(
791829
child: Arc<dyn PhysicalExpr>,
@@ -1799,12 +1837,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
17991837
),
18001838
(DataType::Decimal128(precision, scale), DataType::Int8) => {
18011839
cast_decimal_to_int16_down!(
1802-
array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1840+
array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
18031841
)
18041842
}
18051843
(DataType::Decimal128(precision, scale), DataType::Int16) => {
18061844
cast_decimal_to_int16_down!(
1807-
array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1845+
array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
18081846
)
18091847
}
18101848
(DataType::Decimal128(precision, scale), DataType::Int32) => {

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
529529

530530
test("cast DecimalType(10,2) to ShortType") {
531531
castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
532+
castTest(
533+
generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
534+
DataTypes.ShortType)
532535
}
533536

534537
test("cast DecimalType(10,2) to IntegerType") {
@@ -1189,6 +1192,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11891192
BigDecimal("32768.678"),
11901193
BigDecimal("123456.789"),
11911194
BigDecimal("99999999.999"))
1195+
generateDecimalsPrecision10Scale2(values)
1196+
}
1197+
1198+
private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
11921199
withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
11931200
}
11941201

0 commit comments

Comments
 (0)