
Commit 0b4d75e

fix: cast_struct_to_struct aligns to Spark behavior (#1879)
1 parent 1d0550f commit 0b4d75e

File tree

3 files changed: +56 -32 lines changed

native/core/src/parquet/parquet_support.rs
native/spark-expr/src/conversion_funcs/cast.rs
spark/src/test/scala/org/apache/comet/CometCastSuite.scala

native/core/src/parquet/parquet_support.rs

Lines changed: 21 additions & 15 deletions
@@ -44,6 +44,10 @@ use url::Url;
 
 use super::objectstore;
 
+// This file originates from cast.rs. While developing native scan support and implementing
+// SparkSchemaAdapter we observed that Spark's type conversion logic on Parquet reads does not
+// always align to the CAST expression's logic, so it was duplicated here to adapt its behavior.
+
 static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
 
 static PARQUET_OPTIONS: CastOptions = CastOptions {
@@ -53,7 +57,7 @@ static PARQUET_OPTIONS: CastOptions = CastOptions {
         .with_timestamp_format(TIMESTAMP_FORMAT),
 };
 
-/// Spark cast options
+/// Spark Parquet type conversion options
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub struct SparkParquetOptions {
     /// Spark evaluation mode
@@ -109,7 +113,7 @@ pub fn spark_parquet_convert(
     parquet_options: &SparkParquetOptions,
 ) -> DataFusionResult<ColumnarValue> {
     match arg {
-        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
+        ColumnarValue::Array(array) => Ok(ColumnarValue::Array(parquet_convert_array(
             array,
             data_type,
             parquet_options,
@@ -119,14 +123,16 @@ pub fn spark_parquet_convert(
             // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it
             // here.
             let array = scalar.to_array()?;
-            let scalar =
-                ScalarValue::try_from_array(&cast_array(array, data_type, parquet_options)?, 0)?;
+            let scalar = ScalarValue::try_from_array(
+                &parquet_convert_array(array, data_type, parquet_options)?,
+                0,
+            )?;
             Ok(ColumnarValue::Scalar(scalar))
         }
     }
 }
 
-fn cast_array(
+fn parquet_convert_array(
     array: ArrayRef,
     to_type: &DataType,
     parquet_options: &SparkParquetOptions,
@@ -146,7 +152,7 @@ fn cast_array(
 
             let casted_dictionary = DictionaryArray::<Int32Type>::new(
                 dict_array.keys().clone(),
-                cast_array(Arc::clone(dict_array.values()), to_type, parquet_options)?,
+                parquet_convert_array(Arc::clone(dict_array.values()), to_type, parquet_options)?,
             );
 
             let casted_result = match to_type {
@@ -162,15 +168,15 @@ fn cast_array(
     // Try Comet specific handlers first, then arrow-rs cast if supported,
     // return uncasted data otherwise
     match (from_type, to_type) {
-        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
+        (Struct(_), Struct(_)) => Ok(parquet_convert_struct_to_struct(
            array.as_struct(),
            from_type,
            to_type,
            parquet_options,
        )?),
        (List(_), List(to_inner_type)) => {
            let list_arr: &ListArray = array.as_list();
-            let cast_field = cast_array(
+            let cast_field = parquet_convert_array(
                Arc::clone(list_arr.values()),
                to_inner_type.data_type(),
                parquet_options,
@@ -192,7 +198,7 @@ fn cast_array(
             ))
         }
         (Map(_, ordered_from), Map(_, ordered_to)) if ordered_from == ordered_to =>
-            cast_map_values(array.as_map(), to_type, parquet_options, *ordered_to)
+            parquet_convert_map_to_map(array.as_map(), to_type, parquet_options, *ordered_to)
         ,
         // If Arrow cast supports the cast, delegate the cast to Arrow
         _ if can_cast_types(from_type, to_type) => {
@@ -204,7 +210,7 @@ fn cast_array(
 
 /// Cast between struct types based on logic in
 /// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
-fn cast_struct_to_struct(
+fn parquet_convert_struct_to_struct(
     array: &StructArray,
     from_type: &DataType,
     to_type: &DataType,
@@ -236,7 +242,7 @@ fn cast_struct_to_struct(
         };
         if field_name_to_index_map.contains_key(&key) {
             let from_index = field_name_to_index_map[&key];
-            let cast_field = cast_array(
+            let cast_field = parquet_convert_array(
                 Arc::clone(array.column(from_index)),
                 to_fields[i].data_type(),
                 parquet_options,
@@ -267,8 +273,8 @@ fn cast_struct_to_struct(
 }
 
 /// Cast a map type to another map type. The same as arrow-cast except we recursively call our own
-/// cast_array
-fn cast_map_values(
+/// parquet_convert_array
+fn parquet_convert_map_to_map(
     from: &MapArray,
     to_data_type: &DataType,
     parquet_options: &SparkParquetOptions,
@@ -283,12 +289,12 @@ fn cast_map_values(
         "map is missing value field".to_string(),
     ))?;
 
-    let key_array = cast_array(
+    let key_array = parquet_convert_array(
         Arc::clone(from.keys()),
         key_field.data_type(),
         parquet_options,
    )?;
-    let value_array = cast_array(
+    let value_array = parquet_convert_array(
        Arc::clone(from.values()),
        value_field.data_type(),
        parquet_options,
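
For context, the Parquet read path above keeps matching struct fields by name: parquet_convert_struct_to_struct still builds a field_name_to_index_map and resolves each target field through it, checking field presence first. Below is a minimal sketch of that by-name matching using only the arrow crate; convert_struct_by_name and the plain Arrow cast call are illustrative assumptions rather than the Comet code, which recurses through parquet_convert_array with SparkParquetOptions.

use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StructArray};
use arrow::compute::cast;
use arrow::datatypes::Fields;

// Sketch only: resolve each target field by looking up the source column with the same
// name, then cast that column to the target field's type.
fn convert_struct_by_name(input: &StructArray, from_fields: &Fields, to_fields: &Fields) -> ArrayRef {
    let index_by_name: HashMap<&str, usize> = from_fields
        .iter()
        .enumerate()
        .map(|(i, f)| (f.name().as_str(), i))
        .collect();

    let columns: Vec<ArrayRef> = to_fields
        .iter()
        .map(|to_field| {
            // Panics if the target field is absent from the source; the real code checks first.
            let from_index = index_by_name[to_field.name().as_str()];
            cast(input.column(from_index), to_field.data_type()).unwrap()
        })
        .collect();

    Arc::new(StructArray::new(to_fields.clone(), columns, input.nulls().cloned()))
}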

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 17 additions & 17 deletions
@@ -49,7 +49,6 @@ use num::{
     ToPrimitive,
 };
 use regex::Regex;
-use std::collections::HashMap;
 use std::str::FromStr;
 use std::{
     any::Any,
@@ -1081,22 +1080,23 @@ fn cast_struct_to_struct(
 ) -> DataFusionResult<ArrayRef> {
     match (from_type, to_type) {
         (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
-            // TODO some of this logic may be specific to converting Parquet to Spark
-            let mut field_name_to_index_map = HashMap::new();
-            for (i, field) in from_fields.iter().enumerate() {
-                field_name_to_index_map.insert(field.name(), i);
-            }
-            assert_eq!(field_name_to_index_map.len(), from_fields.len());
-            let mut cast_fields: Vec<ArrayRef> = Vec::with_capacity(to_fields.len());
-            for i in 0..to_fields.len() {
-                let from_index = field_name_to_index_map[to_fields[i].name()];
-                let cast_field = cast_array(
-                    Arc::clone(array.column(from_index)),
-                    to_fields[i].data_type(),
-                    cast_options,
-                )?;
-                cast_fields.push(cast_field);
-            }
+            let cast_fields: Vec<ArrayRef> = from_fields
+                .iter()
+                .enumerate()
+                .zip(to_fields.iter())
+                .map(|((idx, _from), to)| {
+                    let from_field = Arc::clone(array.column(idx));
+                    let array_length = from_field.len();
+                    let cast_result = spark_cast(
+                        ColumnarValue::from(from_field),
+                        to.data_type(),
+                        cast_options,
+                    )
+                    .unwrap();
+                    cast_result.to_array(array_length).unwrap()
+                })
+                .collect();
+
             Ok(Arc::new(StructArray::new(
                 to_fields.clone(),
                 cast_fields,
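
The behavioral change here is that cast_struct_to_struct in the CAST path now pairs fields positionally (from_fields zipped with to_fields) instead of looking each target field up by name, matching Spark's org.apache.spark.sql.catalyst.expressions.Cast#castStruct and the new CometCastSuite test below. A minimal sketch of positional struct conversion follows; cast_struct_positionally and the plain Arrow cast call are illustrative assumptions, whereas the Comet code routes each field through spark_cast.

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Array, StringArray, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Fields};

// Sketch only: source field i is cast to the type of target field i, regardless of names.
fn cast_struct_positionally(input: &StructArray, to_fields: &Fields) -> ArrayRef {
    let columns: Vec<ArrayRef> = to_fields
        .iter()
        .enumerate()
        .map(|(i, to_field)| cast(input.column(i), to_field.data_type()).unwrap())
        .collect();
    Arc::new(StructArray::new(to_fields.clone(), columns, input.nulls().cloned()))
}

fn main() {
    // struct<a: int, b: int> cast to struct<field1: string, field2: string>:
    // "a" feeds "field1" and "b" feeds "field2" purely by position.
    let from_fields = Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]);
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    let b: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
    let input = StructArray::new(from_fields, vec![a, b], None);

    let to_fields = Fields::from(vec![
        Field::new("field1", DataType::Utf8, true),
        Field::new("field2", DataType::Utf8, true),
    ]);
    let result = cast_struct_positionally(&input, &to_fields);
    let result = result.as_any().downcast_ref::<StructArray>().unwrap();
    let field1 = result.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(field1.value(0), "1"); // values came from "a", the name from the target type
}

With the removed name-based lookup, field_name_to_index_map[to_fields[i].name()] would have panicked on a renaming cast like the one in the new test, since the source struct has no field named field1.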

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -972,6 +972,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast StructType to StructType with different names") {
+    withTable("tab1") {
+      sql("""
+          |CREATE TABLE tab1 (s struct<a: string, b: string>)
+          |USING parquet
+        """.stripMargin)
+      sql("INSERT INTO TABLE tab1 SELECT named_struct('col1','1','col2','2')")
+      if (usingDataSourceExec) {
+        checkSparkAnswerAndOperator(
+          "SELECT CAST(s AS struct<field1:string, field2:string>) AS new_struct FROM tab1")
+      } else {
+        // Should just fall back to Spark since non-DataSourceExec scan does not support nested types.
+        checkSparkAnswer(
+          "SELECT CAST(s AS struct<field1:string, field2:string>) AS new_struct FROM tab1")
+      }
+    }
+  }
+
   test("cast between decimals with different precision and scale") {
     val rowData = Seq(
       Row(BigDecimal("12345.6789")),
