Skip to content

Commit 5ea5863

Browse files
authored
feat: support_decimal_types_bool_cast_native_impl (#2490)
1 parent 5372de3 commit 5ea5863

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

docs/source/user-guide/latest/compatibility.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ The following cast operations are generally compatible with Spark except for the
193193
| double | long | |
194194
| double | float | |
195195
| double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
196+
| decimal | boolean | |
196197
| decimal | byte | |
197198
| decimal | short | |
198199
| decimal | integer | |

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ use crate::{timezone, BinaryOutputStyle};
2020
use crate::{EvalMode, SparkError, SparkResult};
2121
use arrow::array::builder::StringBuilder;
2222
use arrow::array::{
23-
Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray, StructArray,
23+
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
24+
StructArray,
2425
};
2526
use arrow::compute::can_cast_types;
2627
use arrow::datatypes::{
@@ -52,7 +53,7 @@ use datafusion::physical_expr::PhysicalExpr;
5253
use datafusion::physical_plan::ColumnarValue;
5354
use num::{
5455
cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
55-
ToPrimitive,
56+
ToPrimitive, Zero,
5657
};
5758
use regex::Regex;
5859
use std::str::FromStr;
@@ -1020,6 +1021,7 @@ fn cast_array(
10201021
{
10211022
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
10221023
}
1024+
(Decimal128(_p, _s), Boolean) => spark_cast_decimal_to_boolean(&array),
10231025
(Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
10241026
(Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
10251027
(Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
@@ -1678,6 +1680,19 @@ where
16781680
Ok(Arc::new(output_array))
16791681
}
16801682

1683+
fn spark_cast_decimal_to_boolean(array: &dyn Array) -> SparkResult<ArrayRef> {
1684+
let decimal_array = array.as_primitive::<Decimal128Type>();
1685+
let mut result = BooleanBuilder::with_capacity(decimal_array.len());
1686+
for i in 0..decimal_array.len() {
1687+
if decimal_array.is_null(i) {
1688+
result.append_null()
1689+
} else {
1690+
result.append_value(!decimal_array.value(i).is_zero());
1691+
}
1692+
}
1693+
Ok(Arc::new(result.finish()))
1694+
}
1695+
16811696
fn spark_cast_nonintegral_numeric_to_integral(
16821697
array: &dyn Array,
16831698
eval_mode: EvalMode,

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
338338

339339
private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
340340
case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
341-
DataTypes.IntegerType | DataTypes.LongType =>
341+
DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType =>
342342
Compatible()
343343
case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not supported"))
344344
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
519519

520520
// CAST from DecimalType(10,2)
521521

522-
ignore("cast DecimalType(10,2) to BooleanType") {
523-
// Arrow error: Cast error: Casting from Decimal128(38, 18) to Boolean not supported
522+
test("cast DecimalType(10,2) to BooleanType") {
524523
castTest(generateDecimalsPrecision10Scale2(), DataTypes.BooleanType)
525524
}
526525

0 commit comments

Comments
 (0)