
Commit 7148d1c

feat: Add support for Spark-compatible cast from integral to decimal (#2472)
1 parent 46057a7 commit 7148d1c

4 files changed (+128, -14 lines)


docs/source/user-guide/latest/compatibility.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -169,13 +169,15 @@ The following cast operations are generally compatible with Spark except for the
 | integer | long | |
 | integer | float | |
 | integer | double | |
+| integer | decimal | |
 | integer | string | |
 | long | boolean | |
 | long | byte | |
 | long | short | |
 | long | integer | |
 | long | float | |
 | long | double | |
+| long | decimal | |
 | long | string | |
 | float | boolean | |
 | float | byte | |
@@ -222,8 +224,6 @@ The following cast operations are not compatible with Spark for all inputs and a
 <!--BEGIN:INCOMPAT_CAST_TABLE-->
 | From Type | To Type | Notes |
 |-|-|-|
-| integer | decimal | No overflow check |
-| long | decimal | No overflow check |
 | float | decimal | There can be rounding differences |
 | double | decimal | There can be rounding differences |
 | string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
```

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 108 additions & 5 deletions
```diff
@@ -19,7 +19,9 @@ use crate::utils::array_with_timezone;
 use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
-use arrow::array::{DictionaryArray, GenericByteArray, StringArray, StructArray};
+use arrow::array::{
+    Decimal128Builder, DictionaryArray, GenericByteArray, StringArray, StructArray,
+};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
     ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
@@ -983,6 +985,9 @@ fn cast_array(
         {
             spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
         }
+        (Int8 | Int16 | Int32 | Int64, Decimal128(precision, scale)) => {
+            cast_int_to_decimal128(&array, eval_mode, from_type, to_type, *precision, *scale)
+        }
         (Utf8, Int8 | Int16 | Int32 | Int64) => {
             cast_string_to_int::<i32>(to_type, &array, eval_mode)
         }
@@ -1143,9 +1148,6 @@ fn is_datafusion_spark_compatible(
                 | DataType::Utf8
         ),
         DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
-            // note that the cast from Int32/Int64 -> Decimal128 here is actually
-            // not compatible with Spark (no overflow checks) but we have tests that
-            // rely on this cast working, so we have to leave it here for now
             matches!(
                 to_type,
                 DataType::Boolean
@@ -1155,7 +1157,6 @@
                     | DataType::Int64
                     | DataType::Float32
                     | DataType::Float64
-                    | DataType::Decimal128(_, _)
                     | DataType::Utf8
             )
         }
@@ -1464,6 +1465,108 @@
     cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
 }
 
+fn cast_int_to_decimal128_internal<T>(
+    array: &PrimitiveArray<T>,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    T::Native: Into<i128>,
+{
+    let mut builder = Decimal128Builder::with_capacity(array.len());
+    let multiplier = 10_i128.pow(scale as u32);
+
+    for i in 0..array.len() {
+        if array.is_null(i) {
+            builder.append_null();
+        } else {
+            let v = array.value(i).into();
+            let scaled = v.checked_mul(multiplier);
+            match scaled {
+                Some(scaled) => {
+                    if !is_validate_decimal_precision(scaled, precision) {
+                        match eval_mode {
+                            EvalMode::Ansi => {
+                                return Err(SparkError::NumericValueOutOfRange {
+                                    value: v.to_string(),
+                                    precision,
+                                    scale,
+                                });
+                            }
+                            EvalMode::Try | EvalMode::Legacy => builder.append_null(),
+                        }
+                    } else {
+                        builder.append_value(scaled);
+                    }
+                }
+                _ => match eval_mode {
+                    EvalMode::Ansi => {
+                        return Err(SparkError::NumericValueOutOfRange {
+                            value: v.to_string(),
+                            precision,
+                            scale,
+                        })
+                    }
+                    EvalMode::Legacy | EvalMode::Try => builder.append_null(),
+                },
+            }
+        }
+    }
+    Ok(Arc::new(
+        builder.with_precision_and_scale(precision, scale)?.finish(),
+    ))
+}
+
+fn cast_int_to_decimal128(
+    array: &dyn Array,
+    eval_mode: EvalMode,
+    from_type: &DataType,
+    to_type: &DataType,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    match (from_type, to_type) {
+        (DataType::Int8, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int8Type>(
+                array.as_primitive::<Int8Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        (DataType::Int16, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int16Type>(
+                array.as_primitive::<Int16Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        (DataType::Int32, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int32Type>(
+                array.as_primitive::<Int32Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        (DataType::Int64, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int64Type>(
+                array.as_primitive::<Int64Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unsupported cast from datatype : {}",
+            from_type
+        ))),
+    }
+}
+
 fn spark_cast_int_to_int(
     array: &dyn Array,
     eval_mode: EvalMode,
```
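Side note (not part of the commit): the per-value logic in `cast_int_to_decimal128_internal` can be summarized in a small standalone sketch. The names `EvalModeSketch` and `cast_one` below are illustrative only; the real code uses arrow's `Decimal128Builder` and Comet's `EvalMode`/`SparkError` types.

```rust
// Standalone sketch of the overflow handling above: scale the integer by 10^scale,
// then keep the value, produce NULL, or raise an error depending on the eval mode.
#[derive(Clone, Copy)]
enum EvalModeSketch {
    Legacy, // overflow becomes NULL
    Try,    // overflow becomes NULL
    Ansi,   // overflow becomes an error
}

fn cast_one(v: i64, precision: u8, scale: i8, mode: EvalModeSketch) -> Result<Option<i128>, String> {
    let multiplier = 10_i128.pow(scale as u32);
    match (v as i128).checked_mul(multiplier) {
        // The scaled value fits if it needs at most `precision` decimal digits.
        Some(scaled) if scaled.abs() < 10_i128.pow(precision as u32) => Ok(Some(scaled)),
        _ => match mode {
            EvalModeSketch::Ansi => Err(format!(
                "{v} cannot be represented as Decimal({precision}, {scale})"
            )),
            EvalModeSketch::Legacy | EvalModeSketch::Try => Ok(None),
        },
    }
}

fn main() {
    // 42 cast to Decimal(10, 2) becomes unscaled 4200, i.e. 42.00.
    assert_eq!(cast_one(42, 10, 2, EvalModeSketch::Legacy), Ok(Some(4200)));
    // 100000001 * 100 needs 11 digits, overflowing Decimal(10, 2).
    assert_eq!(cast_one(100_000_001, 10, 2, EvalModeSketch::Try), Ok(None));
    assert!(cast_one(100_000_001, 10, 2, EvalModeSketch::Ansi).is_err());
}
```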

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -284,7 +284,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case DataTypes.FloatType | DataTypes.DoubleType =>
         Compatible()
       case _: DecimalType =>
-        Incompatible(Some("No overflow check"))
+        Compatible()
       case _ =>
         unsupported(DataTypes.IntegerType, toType)
     }
@@ -297,7 +297,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case DataTypes.FloatType | DataTypes.DoubleType =>
         Compatible()
       case _: DecimalType =>
-        Incompatible(Some("No overflow check"))
+        Compatible()
       case _ =>
         unsupported(DataTypes.LongType, toType)
     }
```

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 16 additions & 5 deletions
```diff
@@ -322,11 +322,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.DoubleType)
   }
 
-  ignore("cast IntegerType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+  test("cast IntegerType to DecimalType(10,2)") {
     castTest(generateInts(), DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast IntegerType to DecimalType(10,2) overflow check") {
+    val intToDecimal10OverflowValues =
+      Seq(Int.MinValue, -100000000, -100000001, 100000000, 100000001, Int.MaxValue).toDF("a")
+    castTest(intToDecimal10OverflowValues, DataTypes.createDecimalType(10, 2))
+  }
+
+  test("cast IntegerType to DecimalType check arbitrary scale and precision") {
+    Seq(DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE, 0, 10, 15)
+      .combinations(2)
+      .map({ c =>
+        castTest(generateInts(), DataTypes.createDecimalType(c.head, c.last))
+      })
+  }
+
   test("cast IntegerType to StringType") {
     castTest(generateInts(), DataTypes.StringType)
   }
```
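A note on the overflow test data above (editorial, not from the commit): the largest value a Decimal(10, 2) can hold is 99999999.99, so every entry in `intToDecimal10OverflowValues` (all with absolute value of at least 100000000) overflows, which is exactly what exercises the null-versus-error paths. A quick check of that bound, assuming nothing beyond the decimal definition itself:

```rust
fn main() {
    // Decimal(10, 2): at most 10 significant digits, 2 of them after the point.
    let max_unscaled: i128 = 10_i128.pow(10) - 1; // 9_999_999_999
    println!("max Decimal(10,2) = {}.{:02}", max_unscaled / 100, max_unscaled % 100); // 99999999.99
    // 100_000_000 scaled by 10^2 already exceeds the bound, so it overflows.
    assert!(100_000_000_i128 * 100 > max_unscaled);
}
```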
```diff
@@ -369,8 +382,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.DoubleType)
   }
 
-  ignore("cast LongType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+  test("cast LongType to DecimalType(10,2)") {
     castTest(generateLongs(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -1232,7 +1244,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       case (None, None) =>
         // neither system threw an exception
       case (None, Some(e)) =>
-        // Spark succeeded but Comet failed
         throw e
       case (Some(e), None) =>
         // Spark failed but Comet succeeded
```
