Skip to content

Commit eb686bf

Browse files
committed
int_to_binary_boolean_to_decimal
1 parent 0e75965 commit eb686bf

File tree

3 files changed

+45
-13
lines changed

3 files changed

+45
-13
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::utils::array_with_timezone;
19+
use crate::EvalMode::Legacy;
1920
use crate::{timezone, BinaryOutputStyle};
2021
use crate::{EvalMode, SparkError, SparkResult};
2122
use arrow::array::builder::StringBuilder;
@@ -25,8 +26,8 @@ use arrow::array::{
2526
};
2627
use arrow::compute::can_cast_types;
2728
use arrow::datatypes::{
28-
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type,
29-
GenericBinaryType, Schema,
29+
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
30+
Schema,
3031
};
3132
use arrow::{
3233
array::{
@@ -66,7 +67,6 @@ use std::{
6667
num::Wrapping,
6768
sync::Arc,
6869
};
69-
use crate::EvalMode::Legacy;
7070

7171
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
7272

@@ -305,7 +305,10 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
305305

306306
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
307307
use DataType::*;
308-
matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
308+
matches!(
309+
to_type,
310+
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
311+
)
309312
}
310313

311314
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
@@ -1121,9 +1124,18 @@ fn cast_array(
11211124
}
11221125
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
11231126
(Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
1124-
(Int16, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int16Array, 2),
1125-
(Int32, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int32Array, 4),
1126-
(Int64, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int64Array, 8),
1127+
(Int16, Binary) if (eval_mode == Legacy) => {
1128+
cast_whole_num_to_binary!(&array, Int16Array, 2)
1129+
}
1130+
(Int32, Binary) if (eval_mode == Legacy) => {
1131+
cast_whole_num_to_binary!(&array, Int32Array, 4)
1132+
}
1133+
(Int64, Binary) if (eval_mode == Legacy) => {
1134+
cast_whole_num_to_binary!(&array, Int64Array, 8)
1135+
}
1136+
(Boolean, Decimal128(precision, scale)) => {
1137+
cast_boolean_to_decimal(&array, *precision, *scale)
1138+
}
11271139
_ if cast_options.is_adapting_schema
11281140
|| is_datafusion_spark_compatible(from_type, to_type) =>
11291141
{
@@ -1142,6 +1154,16 @@ fn cast_array(
11421154
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
11431155
}
11441156

1157+
fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
1158+
let bool_array = array.as_boolean();
1159+
let scale_factor = 10_i128.pow(scale as u32);
1160+
let result: Decimal128Array = bool_array
1161+
.iter()
1162+
.map(|v| v.map(|b| if b { scale_factor } else { 0 }))
1163+
.collect();
1164+
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
1165+
}
1166+
11451167
fn cast_string_to_float(
11461168
array: &ArrayRef,
11471169
to_type: &DataType,

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.expressions
2121

2222
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Literal}
2323
import org.apache.spark.sql.internal.SQLConf
24-
import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, DecimalType, NullType, StructType}
24+
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
2525

2626
import org.apache.comet.CometConf
2727
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -263,7 +263,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
263263

264264
private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
265265
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
266-
DataTypes.FloatType | DataTypes.DoubleType =>
266+
DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
267267
Compatible()
268268
case _ => unsupported(DataTypes.BooleanType, toType)
269269
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
135135
castTest(generateBools(), DataTypes.DoubleType)
136136
}
137137

138-
ignore("cast BooleanType to DecimalType(10,2)") {
139-
// Arrow error: Cast error: Casting from Boolean to Decimal128(10, 2) not supported
138+
test("cast BooleanType to DecimalType(10,2)") {
140139
castTest(generateBools(), DataTypes.createDecimalType(10, 2))
141140
}
142141

142+
test("cast BooleanType to DecimalType(14,4)") {
143+
castTest(generateBools(), DataTypes.createDecimalType(14, 4))
144+
}
145+
146+
test("cast BooleanType to DecimalType(30,0)") {
147+
castTest(generateBools(), DataTypes.createDecimalType(30, 0))
148+
}
149+
143150
test("cast BooleanType to StringType") {
144151
castTest(generateBools(), DataTypes.StringType)
145152
}
@@ -1353,9 +1360,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13531360
}
13541361

13551362
if (testTry) {
1363+
data.createOrReplaceTempView("t")
13561364
// try_cast() should always return null for invalid inputs
1365+
// not using the Spark DSL since `try_cast` is only available from Spark 4.x
13571366
val df2 =
1358-
data.select(col("a"), col("a").try_cast(toType)).orderBy(col("a"))
1367+
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
13591368
if (hasIncompatibleType) {
13601369
checkSparkAnswer(df2)
13611370
} else {
@@ -1419,8 +1428,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14191428

14201429
// try_cast() should always return null for invalid inputs
14211430
if (testTry) {
1431+
data.createOrReplaceTempView("t")
14221432
val df2 =
1423-
data.select(col("a"), col("a").try_cast(toType)).orderBy(col("a"))
1433+
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
14241434
if (hasIncompatibleType) {
14251435
checkSparkAnswer(df2)
14261436
} else {

0 commit comments

Comments (0)