diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index dd059abbcd..0c5e1bcde0 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -173,6 +173,7 @@ The following cast operations are generally compatible with Spark except for the
 | integer | long | |
 | integer | float | |
 | integer | double | |
+| integer | decimal | |
 | integer | string | |
 | long | boolean | |
 | long | byte | |
@@ -180,6 +181,7 @@ The following cast operations are generally compatible with Spark except for the
 | long | integer | |
 | long | float | |
 | long | double | |
+| long | decimal | |
 | long | string | |
 | float | boolean | |
 | float | byte | |
@@ -226,8 +228,6 @@ The following cast operations are not compatible with Spark for all inputs and a
 
 | From Type | To Type | Notes |
 |-|-|-|
-| integer | decimal | No overflow check |
-| long | decimal | No overflow check |
 | float | decimal | There can be rounding differences |
 | double | decimal | There can be rounding differences |
 | string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index af997ccf80..da568548aa 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -19,7 +19,9 @@ use crate::utils::array_with_timezone;
 use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
-use arrow::array::{DictionaryArray, GenericByteArray, StringArray, StructArray};
+use arrow::array::{
+    Decimal128Builder, DictionaryArray, GenericByteArray, StringArray, StructArray,
+};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
     ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
@@ -983,6 +985,9 @@ fn cast_array(
         {
             spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
         }
+        (Int8 | Int16 | Int32 | Int64, Decimal128(precision, scale)) => {
+            cast_int_to_decimal128(&array, eval_mode, from_type, to_type, *precision, *scale)
+        }
         (Utf8, Int8 | Int16 | Int32 | Int64) => {
             cast_string_to_int::<i32>(to_type, &array, eval_mode)
         }
@@ -1143,9 +1148,6 @@ fn is_datafusion_spark_compatible(
                 | DataType::Utf8
         ),
         DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
-            // note that the cast from Int32/Int64 -> Decimal128 here is actually
-            // not compatible with Spark (no overflow checks) but we have tests that
-            // rely on this cast working, so we have to leave it here for now
             matches!(
                 to_type,
                 DataType::Boolean
@@ -1155,7 +1157,6 @@ fn is_datafusion_spark_compatible(
                     | DataType::Int64
                     | DataType::Float32
                     | DataType::Float64
-                    | DataType::Decimal128(_, _)
                     | DataType::Utf8
             )
         }
@@ -1464,6 +1465,108 @@ where
     cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
 }
 
+fn cast_int_to_decimal128_internal<T>(
+    array: &PrimitiveArray<T>,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    T::Native: Into<i128>,
+{
+    let mut builder = Decimal128Builder::with_capacity(array.len());
+    let multiplier = 10_i128.pow(scale as u32);
+
+    for i in 0..array.len() {
+        if array.is_null(i) {
+            builder.append_null();
+        } else {
+            let v = array.value(i).into();
+            let scaled = v.checked_mul(multiplier);
+            match scaled {
+                Some(scaled) => {
+                    if !is_validate_decimal_precision(scaled, precision) {
+                        match eval_mode {
+                            EvalMode::Ansi => {
+                                return Err(SparkError::NumericValueOutOfRange {
+                                    value: v.to_string(),
+                                    precision,
+                                    scale,
+                                });
+                            }
+                            EvalMode::Try | EvalMode::Legacy => builder.append_null(),
+                        }
+                    } else {
+                        builder.append_value(scaled);
+                    }
+                }
+                _ => match eval_mode {
+                    EvalMode::Ansi => {
+                        return Err(SparkError::NumericValueOutOfRange {
+                            value: v.to_string(),
+                            precision,
+                            scale,
+                        })
+                    }
+                    EvalMode::Legacy | EvalMode::Try => builder.append_null(),
+                },
+            }
+        }
+    }
+    Ok(Arc::new(
+        builder.with_precision_and_scale(precision, scale)?.finish(),
+    ))
+}
+
+fn cast_int_to_decimal128(
+    array: &dyn Array,
+    eval_mode: EvalMode,
+    from_type: &DataType,
+    to_type: &DataType,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    match (from_type, to_type) {
+        (DataType::Int8, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int8Type>(
+                array.as_primitive::<Int8Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        (DataType::Int16, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int16Type>(
+                array.as_primitive::<Int16Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        (DataType::Int32, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int32Type>(
+                array.as_primitive::<Int32Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        (DataType::Int64, DataType::Decimal128(_p, _s)) => {
+            cast_int_to_decimal128_internal::<Int64Type>(
+                array.as_primitive::<Int64Type>(),
+                precision,
+                scale,
+                eval_mode,
+            )
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unsupported cast from datatype : {}",
+            from_type
+        ))),
+    }
+}
+
 fn spark_cast_int_to_int(
     array: &dyn Array,
     eval_mode: EvalMode,
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 7db62130d4..c68fbf0512 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -284,7 +284,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case DataTypes.FloatType | DataTypes.DoubleType =>
         Compatible()
       case _: DecimalType =>
-        Incompatible(Some("No overflow check"))
+        Compatible()
       case _ => unsupported(DataTypes.IntegerType, toType)
     }
 
@@ -297,7 +297,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case DataTypes.FloatType | DataTypes.DoubleType =>
         Compatible()
       case _: DecimalType =>
-        Incompatible(Some("No overflow check"))
+        Compatible()
       case _ => unsupported(DataTypes.LongType, toType)
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 2667b40877..c5a4309a7d 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -322,11 +322,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.DoubleType)
   }
 
-  ignore("cast IntegerType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+  test("cast IntegerType to DecimalType(10,2)") {
     castTest(generateInts(), DataTypes.createDecimalType(10, 2))
   }
 
+  test("cast IntegerType to DecimalType(10,2) overflow check") {
+    val intToDecimal10OverflowValues =
+      Seq(Int.MinValue, -100000000, -100000001, 100000000, 100000001, Int.MaxValue).toDF("a")
+    castTest(intToDecimal10OverflowValues, DataTypes.createDecimalType(10, 2))
+  }
+
+  test("cast IntegerType to DecimalType check arbitrary scale and precision") {
+    Seq(DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE, 0, 10, 15)
+      .combinations(2)
+      .map({ c =>
+        castTest(generateInts(), DataTypes.createDecimalType(c.head, c.last))
+      })
+  }
+
   test("cast IntegerType to StringType") {
     castTest(generateInts(), DataTypes.StringType)
   }
@@ -369,8 +382,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.DoubleType)
   }
 
-  ignore("cast LongType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] -1117686336 cannot be represented as Decimal(10, 2)
+  test("cast LongType to DecimalType(10,2)") {
     castTest(generateLongs(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -1232,7 +1244,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       case (None, None) =>
       // neither system threw an exception
       case (None, Some(e)) =>
-        // Spark succeeded but Comet failed
         throw e
       case (Some(e), None) =>
         // Spark failed but Comet succeeded
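
A note on the values exercised by the new `cast IntegerType to DecimalType(10,2) overflow check` test: a `Decimal(10, 2)` holds at most 10 significant digits, so its largest unscaled value is 9,999,999,999. After multiplying by 10^2, the largest integer that still fits is 99,999,999, which is why ±100,000,000 and anything larger in magnitude (including `Int.MinValue` and `Int.MaxValue`) are expected to overflow. Below is a minimal standalone sketch of that scale-then-check logic, for illustration only; `to_decimal128_unscaled` is a hypothetical helper, not the Comet or Arrow implementation.

// Hypothetical sketch (not Comet/Arrow code): scale an integer by 10^scale and
// check that the result still fits in `precision` digits. The real cast maps
// the None case to null in Legacy/Try mode and to a NUMERIC_VALUE_OUT_OF_RANGE
// error in ANSI mode.
fn to_decimal128_unscaled(v: i128, precision: u32, scale: u32) -> Option<i128> {
    let max_unscaled = 10_i128.pow(precision) - 1; // 9_999_999_999 for precision 10
    let scaled = v.checked_mul(10_i128.pow(scale))?; // None on i128 overflow
    if scaled.abs() <= max_unscaled {
        Some(scaled)
    } else {
        None
    }
}

fn main() {
    // 99_999_999 scales to 9_999_999_900, which fits in Decimal(10, 2)
    assert_eq!(to_decimal128_unscaled(99_999_999, 10, 2), Some(9_999_999_900));
    // 100_000_000 scales to 10_000_000_000 (11 digits), so it overflows
    assert_eq!(to_decimal128_unscaled(100_000_000, 10, 2), None);
    // Int.MaxValue overflows as well
    assert_eq!(to_decimal128_unscaled(2_147_483_647, 10, 2), None);
}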