Commit d98b0be

feat: Implement array-to-string cast support (#2425)
1 parent 09dc7cc · commit d98b0be

5 files changed · +151 −6 lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 99 additions & 1 deletion
@@ -20,7 +20,7 @@ use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    Decimal128Builder, DictionaryArray, GenericByteArray, StringArray, StructArray,
+    Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray, StructArray,
 };
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
@@ -1028,6 +1028,7 @@ fn cast_array(
             to_type,
             cast_options,
         )?),
+        (List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
         (List(_), List(_)) if can_cast_types(from_type, to_type) => {
            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
         }
@@ -1240,6 +1241,52 @@ fn cast_struct_to_struct(
     }
 }

+fn cast_array_to_string(
+    array: &ListArray,
+    spark_cast_options: &SparkCastOptions,
+) -> DataFusionResult<ArrayRef> {
+    let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
+    let mut str = String::with_capacity(array.len() * 16);
+
+    let casted_values = cast_array(
+        Arc::clone(array.values()),
+        &DataType::Utf8,
+        spark_cast_options,
+    )?;
+    let string_values = casted_values
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .expect("Casted values should be StringArray");
+
+    let offsets = array.offsets();
+    for row_index in 0..array.len() {
+        if array.is_null(row_index) {
+            builder.append_null();
+        } else {
+            str.clear();
+            let start = offsets[row_index] as usize;
+            let end = offsets[row_index + 1] as usize;
+
+            str.push('[');
+            let mut first = true;
+            for idx in start..end {
+                if !first {
+                    str.push_str(", ");
+                }
+                if string_values.is_null(idx) {
+                    str.push_str(&spark_cast_options.null_string);
+                } else {
+                    str.push_str(string_values.value(idx));
+                }
+                first = false;
+            }
+            str.push(']');
+            builder.append_value(&str);
+        }
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
 fn casts_struct_to_string(
     array: &StructArray,
     spark_cast_options: &SparkCastOptions,
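
Note on the layout the new function relies on: an Arrow ListArray keeps every element of every row in one flat child array, and row i spans offsets[i]..offsets[i+1] of it, which is why cast_array_to_string can cast array.values() to Utf8 once and then slice per row. Below is a minimal standalone sketch of that traversal, using plain arrow-rs outside Comet; the hardcoded "null" placeholder stands in for SparkCastOptions::null_string.

use std::sync::Arc;

use arrow::array::{Array, Int32Array, ListArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};

fn main() {
    // Flat child array holding every element of every list row.
    let values = Int32Array::from(vec![Some(1), Some(2), None, Some(4)]);
    // offsets[i]..offsets[i+1] delimits row i: rows are [1, 2], [null], [4].
    let offsets = OffsetBuffer::<i32>::new(vec![0, 2, 3, 4].into());
    let field = Arc::new(Field::new("item", DataType::Int32, true));
    let list = ListArray::new(field, offsets, Arc::new(values), None);

    let child = list.values().as_any().downcast_ref::<Int32Array>().unwrap();
    for row in 0..list.len() {
        let (start, end) = (
            list.offsets()[row] as usize,
            list.offsets()[row + 1] as usize,
        );
        // Same traversal shape as cast_array_to_string: walk the child
        // slice for this row and render null elements with a placeholder.
        let items: Vec<String> = (start..end)
            .map(|i| {
                if child.is_null(i) {
                    "null".to_string()
                } else {
                    child.value(i).to_string()
                }
            })
            .collect();
        println!("[{}]", items.join(", ")); // prints [1, 2] then [null] then [4]
    }
}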
@@ -2928,4 +2975,55 @@ mod tests {
         assert!(casted.is_null(8));
         assert!(casted.is_null(9));
     }
+
+    #[test]
+    fn test_cast_string_array_to_string() {
+        use arrow::array::ListArray;
+        use arrow::buffer::OffsetBuffer;
+        let values_array =
+            StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
+        let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
+        let item_field = Arc::new(Field::new("item", DataType::Utf8, true));
+        let list_array = Arc::new(ListArray::new(
+            item_field,
+            offsets_buffer,
+            Arc::new(values_array),
+            None,
+        ));
+        let string_array = cast_array_to_string(
+            &list_array,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
+        )
+        .unwrap();
+        let string_array = string_array.as_string::<i32>();
+        assert_eq!(r#"[a, b, c]"#, string_array.value(0));
+        assert_eq!(r#"[a, null]"#, string_array.value(1));
+        assert_eq!(r#"[null]"#, string_array.value(2));
+        assert_eq!(r#"[]"#, string_array.value(3));
+    }
+
+    #[test]
+    fn test_cast_i32_array_to_string() {
+        use arrow::array::ListArray;
+        use arrow::buffer::OffsetBuffer;
+        let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
+        let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
+        let item_field = Arc::new(Field::new("item", DataType::Int32, true));
+        let list_array = Arc::new(ListArray::new(
+            item_field,
+            offsets_buffer,
+            Arc::new(values_array),
+            None,
+        ));
+        let string_array = cast_array_to_string(
+            &list_array,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
+        )
+        .unwrap();
+        let string_array = string_array.as_string::<i32>();
+        assert_eq!(r#"[1, 2, 3]"#, string_array.value(0));
+        assert_eq!(r#"[1, null]"#, string_array.value(1));
+        assert_eq!(r#"[null]"#, string_array.value(2));
+        assert_eq!(r#"[]"#, string_array.value(3));
+    }
 }
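
Both tests above exercise the empty-list case (the zero-length 6..6 offset range renders as "[]") but pass None for the validity buffer, so no row is NULL. The following sketch shows the distinction a validity buffer adds, for which cast_array_to_string appends a NULL output value rather than "[]"; it uses plain arrow-rs and is illustrative only.

use std::sync::Arc;

use arrow::array::{Array, Int32Array, ListArray};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::datatypes::{DataType, Field};

fn main() {
    let values = Int32Array::from(vec![1, 2, 3]);
    // Row 0 spans 0..3; rows 1 and 2 are both zero-length (3..3).
    let offsets = OffsetBuffer::<i32>::new(vec![0, 3, 3, 3].into());
    let field = Arc::new(Field::new("item", DataType::Int32, true));
    // Validity buffer: row 1 is NULL, rows 0 and 2 are valid.
    let nulls = NullBuffer::from(vec![true, false, true]);
    let list = ListArray::new(field, offsets, Arc::new(values), Some(nulls));

    assert!(!list.is_null(0)); // valid row: would render as "[1, 2, 3]"
    assert!(list.is_null(1)); // NULL row: builder.append_null(), not "[]"
    assert!(!list.is_null(2)); // valid but empty row: renders as "[]"
}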

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {

     (fromType, toType) match {
       case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
+      case (dt: ArrayType, DataTypes.StringType) =>
+        isSupported(dt.elementType, DataTypes.StringType, timeZoneId, evalMode)
       case (dt: ArrayType, dt1: ArrayType) =>
         isSupported(dt.elementType, dt1.elementType, timeZoneId, evalMode)
       case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 34 additions & 1 deletion
@@ -21,6 +21,7 @@ package org.apache.comet

 import java.io.File

+import scala.collection.mutable.ListBuffer
 import scala.util.Random
 import scala.util.matching.Regex

@@ -30,10 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}

 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
+import org.apache.comet.rules.CometScanTypeChecker
 import org.apache.comet.serde.Compatible

 class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -1046,6 +1048,31 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
   }

+  test("cast ArrayType to StringType") {
+    val hasIncompatibleType = (dt: DataType) =>
+      if (CometConf.COMET_NATIVE_SCAN_IMPL.get() == "auto") {
+        true
+      } else {
+        !CometScanTypeChecker(CometConf.COMET_NATIVE_SCAN_IMPL.get())
+          .isTypeSupported(dt, "a", ListBuffer.empty)
+      }
+    Seq(
+      BooleanType,
+      StringType,
+      ByteType,
+      IntegerType,
+      LongType,
+      ShortType,
+      // FloatType,
+      // DoubleType,
+      DecimalType(10, 2),
+      DecimalType(38, 18),
+      BinaryType).foreach { dt =>
+      val input = generateArrays(100, dt)
+      castTest(input, StringType, hasIncompatibleType = hasIncompatibleType(input.schema))
+    }
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
@@ -1074,6 +1101,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(gen.generateLongs(dataSize)).toDF("a")
   }

+  private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
+    import scala.collection.JavaConverters._
+    val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
+    spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
+  }
+
   // https://github.com/apache/datafusion-comet/issues/2038
   test("test implicit cast to dictionary with case when and dictionary type") {
     withSQLConf("parquet.enable.dictionary" -> "true") {

spark/src/test/spark-3.5/org/apache/spark/sql/CometToPrettyStringSuite.scala

Lines changed: 8 additions & 2 deletions
@@ -19,14 +19,17 @@

 package org.apache.spark.sql

+import scala.collection.mutable.ListBuffer
+
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.types.DataTypes

-import org.apache.comet.CometFuzzTestBase
+import org.apache.comet.{CometConf, CometFuzzTestBase}
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
+import org.apache.comet.rules.CometScanTypeChecker
 import org.apache.comet.serde.Compatible

 class CometToPrettyStringSuite extends CometFuzzTestBase {
@@ -47,7 +50,10 @@ class CometToPrettyStringSuite extends CometFuzzTestBase {
         DataTypes.StringType,
         Some(spark.sessionState.conf.sessionLocalTimeZone),
         CometEvalMode.TRY) match {
-        case _: Compatible => checkSparkAnswerAndOperator(result)
+        case _: Compatible
+            if CometScanTypeChecker(CometConf.COMET_NATIVE_SCAN_IMPL.get())
+              .isTypeSupported(field.dataType, field.name, ListBuffer.empty) =>
+          checkSparkAnswerAndOperator(result)
         case _ => checkSparkAnswer(result)
       }
     }

spark/src/test/spark-4.0/org/apache/spark/sql/CometToPrettyStringSuite.scala

Lines changed: 8 additions & 2 deletions
@@ -19,6 +19,8 @@

 package org.apache.spark.sql

+import scala.collection.mutable.ListBuffer
+
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString}
@@ -28,8 +30,9 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle
 import org.apache.spark.sql.types.DataTypes

-import org.apache.comet.CometFuzzTestBase
+import org.apache.comet.{CometConf, CometFuzzTestBase}
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
+import org.apache.comet.rules.CometScanTypeChecker
 import org.apache.comet.serde.Compatible

 class CometToPrettyStringSuite extends CometFuzzTestBase {
@@ -58,7 +61,10 @@ class CometToPrettyStringSuite extends CometFuzzTestBase {
         DataTypes.StringType,
         Some(spark.sessionState.conf.sessionLocalTimeZone),
         CometEvalMode.TRY) match {
-        case _: Compatible => checkSparkAnswerAndOperator(result)
+        case _: Compatible
+            if CometScanTypeChecker(CometConf.COMET_NATIVE_SCAN_IMPL.get())
+              .isTypeSupported(field.dataType, field.name, ListBuffer.empty) =>
+          checkSparkAnswerAndOperator(result)
         case _ => checkSparkAnswer(result)
       }
     }
