Skip to content

Commit 5ef668c

Browse files
committed
Refactor: Remove CastColumnExpr
1 parent 6746007 commit 5ef668c

File tree

4 files changed

+275
-438
lines changed

4 files changed

+275
-438
lines changed

datafusion/common/src/nested_struct.rs

Lines changed: 60 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -54,31 +54,50 @@ fn cast_struct_column(
5454
target_fields: &[Arc<Field>],
5555
cast_options: &CastOptions,
5656
) -> Result<ArrayRef> {
57+
if source_col.data_type().is_null() {
58+
return Ok(new_null_array(
59+
&Struct(target_fields.to_vec().into()),
60+
source_col.len(),
61+
));
62+
}
63+
5764
if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() {
5865
validate_struct_compatibility(source_struct.fields(), target_fields)?;
5966

6067
let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
6168
let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
6269
let num_rows = source_col.len();
70+
let source_fields = source_struct.fields();
6371

64-
for target_child_field in target_fields {
72+
for (idx, target_child_field) in target_fields.iter().enumerate() {
6573
fields.push(Arc::clone(target_child_field));
66-
match source_struct.column_by_name(target_child_field.name()) {
67-
Some(source_child_col) => {
68-
let adapted_child =
69-
cast_column(source_child_col, target_child_field, cast_options)
70-
.map_err(|e| {
71-
e.context(format!(
72-
"While casting struct field '{}'",
73-
target_child_field.name()
74-
))
75-
})?;
76-
arrays.push(adapted_child);
77-
}
78-
None => {
79-
arrays.push(new_null_array(target_child_field.data_type(), num_rows));
80-
}
81-
}
74+
let source_child_col = source_struct
75+
.column_by_name(target_child_field.name())
76+
.map(Arc::clone)
77+
.or_else(|| {
78+
source_fields.get(idx).and_then(|field| {
79+
if is_positional_field(field, idx) {
80+
source_struct.columns().get(idx).cloned()
81+
} else {
82+
None
83+
}
84+
})
85+
});
86+
87+
let Some(source_child_col) = source_child_col else {
88+
arrays.push(new_null_array(target_child_field.data_type(), num_rows));
89+
continue;
90+
};
91+
92+
let adapted_child =
93+
cast_column(&source_child_col, target_child_field, cast_options)
94+
.map_err(|e| {
95+
e.context(format!(
96+
"While casting struct field '{}'",
97+
target_child_field.name()
98+
))
99+
})?;
100+
arrays.push(adapted_child);
82101
}
83102

84103
let struct_array =
@@ -205,12 +224,22 @@ pub fn validate_struct_compatibility(
205224
target_fields: &[FieldRef],
206225
) -> Result<()> {
207226
// Check compatibility for each target field
208-
for target_field in target_fields {
227+
for (idx, target_field) in target_fields.iter().enumerate() {
209228
// Look for matching field in source by name
210-
if let Some(source_field) = source_fields
229+
let source_field = source_fields
211230
.iter()
212231
.find(|f| f.name() == target_field.name())
213-
{
232+
.or_else(|| {
233+
source_fields.get(idx).and_then(|field| {
234+
if is_positional_field(field, idx) {
235+
Some(field)
236+
} else {
237+
None
238+
}
239+
})
240+
});
241+
242+
if let Some(source_field) = source_field {
214243
// Ensure nullability is compatible. It is invalid to cast a nullable
215244
// source field to a non-nullable target field as this may discard
216245
// null values.
@@ -249,6 +278,17 @@ pub fn validate_struct_compatibility(
249278
Ok(())
250279
}
251280

281+
fn is_positional_field(field: &FieldRef, idx: usize) -> bool {
282+
field.name().is_empty()
283+
|| field
284+
.name()
285+
.as_str()
286+
.strip_prefix('c')
287+
.and_then(|suffix| suffix.parse::<usize>().ok())
288+
.map(|n| n == idx)
289+
.unwrap_or(false)
290+
}
291+
252292
#[cfg(test)]
253293
mod tests {
254294

datafusion/physical-expr/src/expressions/cast.rs

Lines changed: 215 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,10 @@ use arrow::compute::{can_cast_types, CastOptions};
2626
use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema};
2727
use arrow::record_batch::RecordBatch;
2828
use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
29-
use datafusion_common::{not_impl_err, Result};
29+
use datafusion_common::{
30+
nested_struct::{cast_column, validate_struct_compatibility},
31+
not_impl_err, Result, ScalarValue,
32+
};
3033
use datafusion_expr_common::columnar_value::ColumnarValue;
3134
use datafusion_expr_common::interval_arithmetic::Interval;
3235
use datafusion_expr_common::sort_properties::ExprProperties;
@@ -41,6 +44,18 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions {
4144
format_options: DEFAULT_FORMAT_OPTIONS,
4245
};
4346

47+
fn has_positional_fields(fields: &[FieldRef]) -> bool {
48+
fields.iter().enumerate().any(|(idx, f)| {
49+
f.name().is_empty()
50+
|| f.name()
51+
.as_str()
52+
.strip_prefix('c')
53+
.and_then(|suffix| suffix.parse::<usize>().ok())
54+
.map(|n| n == idx)
55+
.unwrap_or(false)
56+
})
57+
}
58+
4459
/// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast
4560
#[derive(Debug, Clone, Eq)]
4661
pub struct CastExpr {
@@ -138,12 +153,48 @@ impl PhysicalExpr for CastExpr {
138153
}
139154

140155
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
141-
self.expr.nullable(input_schema)
156+
if matches!(self.cast_type, Struct(_)) {
157+
Ok(self.return_field(input_schema)?.is_nullable())
158+
} else {
159+
self.expr.nullable(input_schema)
160+
}
142161
}
143162

144163
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
145164
let value = self.expr.evaluate(batch)?;
146-
value.cast_to(&self.cast_type, Some(&self.cast_options))
165+
let Struct(target_fields) = &self.cast_type else {
166+
return value.cast_to(&self.cast_type, Some(&self.cast_options));
167+
};
168+
let Struct(source_fields) = self.expr.data_type(batch.schema().as_ref())? else {
169+
return value.cast_to(&self.cast_type, Some(&self.cast_options));
170+
};
171+
172+
let use_struct_cast = target_fields.len() > source_fields.len()
173+
|| has_positional_fields(&source_fields)
174+
|| has_positional_fields(target_fields)
175+
|| target_fields
176+
.iter()
177+
.any(|t| source_fields.iter().all(|s| s.name() != t.name()));
178+
179+
if !use_struct_cast || source_fields == *target_fields {
180+
return value.cast_to(&self.cast_type, Some(&self.cast_options));
181+
}
182+
183+
let target_field = self.return_field(batch.schema().as_ref())?;
184+
match value {
185+
ColumnarValue::Array(array) => {
186+
let casted =
187+
cast_column(&array, target_field.as_ref(), &self.cast_options)?;
188+
Ok(ColumnarValue::Array(casted))
189+
}
190+
ColumnarValue::Scalar(scalar) => {
191+
let as_array = scalar.to_array_of_size(1)?;
192+
let casted =
193+
cast_column(&as_array, target_field.as_ref(), &self.cast_options)?;
194+
let result = ScalarValue::try_from_array(casted.as_ref(), 0)?;
195+
Ok(ColumnarValue::Scalar(result))
196+
}
197+
}
147198
}
148199

149200
fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
@@ -229,6 +280,13 @@ pub fn cast_with_options(
229280
let expr_type = expr.data_type(input_schema)?;
230281
if expr_type == cast_type {
231282
Ok(Arc::clone(&expr))
283+
} else if let Struct(target_fields) = &cast_type {
284+
if let Struct(source_fields) = expr_type {
285+
validate_struct_compatibility(&source_fields, target_fields)?;
286+
} else if expr_type != Null {
287+
return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}");
288+
}
289+
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
232290
} else if can_cast_types(&expr_type, &cast_type) {
233291
Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options)))
234292
} else {
@@ -252,16 +310,17 @@ pub fn cast(
252310
mod tests {
253311
use super::*;
254312

255-
use crate::expressions::column::col;
313+
use crate::expressions::{column::col, Column, Literal};
256314

257315
use arrow::{
258316
array::{
259-
Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
260-
Int64Array, Int8Array, StringArray, Time64NanosecondArray,
261-
TimestampNanosecondArray, UInt32Array,
317+
Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array,
318+
Int16Array, Int32Array, Int64Array, Int8Array, StringArray, StructArray,
319+
Time64NanosecondArray, TimestampNanosecondArray, UInt32Array,
262320
},
263321
datatypes::*,
264322
};
323+
use datafusion_common::cast::{as_int64_array, as_string_array, as_uint8_array};
265324
use datafusion_physical_expr_common::physical_expr::fmt_sql;
266325
use insta::assert_snapshot;
267326

@@ -809,4 +868,153 @@ mod tests {
809868

810869
Ok(())
811870
}
871+
872+
fn make_schema(field: &Field) -> SchemaRef {
873+
Arc::new(Schema::new(vec![field.clone()]))
874+
}
875+
876+
fn make_struct_array(fields: Fields, arrays: Vec<ArrayRef>) -> StructArray {
877+
StructArray::new(fields, arrays, None)
878+
}
879+
880+
/// Casts one struct array to another with different fields.
881+
fn cast_struct_array(
882+
input: StructArray,
883+
target_type: &DataType,
884+
) -> Result<StructArray> {
885+
let batch = RecordBatch::try_from_iter(vec![("s", Arc::new(input) as ArrayRef)])?;
886+
let column = Arc::new(Column::new_with_schema("s", batch.schema().as_ref())?);
887+
let expr = CastExpr::new(column, target_type.clone(), Some(DEFAULT_CAST_OPTIONS));
888+
889+
let result = expr.evaluate(&batch)?;
890+
let ColumnarValue::Array(array) = result else {
891+
panic!("expected array");
892+
};
893+
let struct_array = array
894+
.as_any()
895+
.downcast_ref::<StructArray>()
896+
.expect("struct array");
897+
Ok(struct_array.clone())
898+
}
899+
900+
/// Ensures struct casts fill missing target fields with nulls and reorder correctly.
901+
/// Input: { "a": [1, null], "b": ["alpha", "beta"] }
902+
/// Output: { "a": [1, null], "c": [null, null] }
903+
#[test]
904+
fn cast_struct_array_missing_child() -> Result<()> {
905+
let source_a = Arc::new(Field::new("a", Int32, true));
906+
let source_b = Arc::new(Field::new("b", Utf8, true));
907+
908+
let struct_array = make_struct_array(
909+
vec![source_a, source_b].into(),
910+
vec![
911+
Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef,
912+
Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")]))
913+
as ArrayRef,
914+
],
915+
);
916+
917+
let target_a = Arc::new(Field::new("a", Int64, true));
918+
let target_c = Arc::new(Field::new("c", Utf8, true));
919+
let target_type = Struct(Fields::from(vec![target_a, target_c]));
920+
921+
let output_array = cast_struct_array(struct_array, &target_type)?;
922+
923+
let cast_a = as_int64_array(output_array.column_by_name("a").unwrap().as_ref())?;
924+
assert_eq!(cast_a.value(0), 1);
925+
assert!(cast_a.is_null(1));
926+
927+
let cast_c = as_string_array(output_array.column_by_name("c").unwrap().as_ref())?;
928+
assert!(cast_c.is_null(0));
929+
assert!(cast_c.is_null(1));
930+
Ok(())
931+
}
932+
933+
/// Verifies nested struct casts recurse through multiple levels preserving
934+
/// values and adding null placeholders.
935+
///
936+
/// Input: { "root": { "inner": { "x": [7, null] } } }
937+
/// Output: { "root": { "inner": { "x": [7, null], "y": [null, null] } } }
938+
#[test]
939+
fn cast_nested_struct_array() -> Result<()> {
940+
let inner_source_fields = Fields::from([Arc::new(Field::new("x", Int32, true))]);
941+
942+
let inner_source = Field::new_struct("inner", inner_source_fields.clone(), true);
943+
944+
let inner_target_fields: Fields = vec![
945+
Arc::new(Field::new("x", Int64, true)),
946+
Arc::new(Field::new("y", Boolean, true)),
947+
]
948+
.into();
949+
let inner_target = Field::new("inner", Struct(inner_target_fields.clone()), true);
950+
let target_type = Struct(vec![Arc::new(inner_target.clone())].into());
951+
952+
let inner_struct = make_struct_array(
953+
inner_source_fields.clone(),
954+
vec![Arc::new(Int32Array::from(vec![Some(7), None])) as ArrayRef],
955+
);
956+
let outer_struct = make_struct_array(
957+
vec![Arc::new(inner_source.clone())].into(),
958+
vec![Arc::new(inner_struct) as ArrayRef],
959+
);
960+
let output_array = cast_struct_array(outer_struct, &target_type)?;
961+
962+
let inner = output_array
963+
.column_by_name("inner")
964+
.unwrap()
965+
.as_any()
966+
.downcast_ref::<StructArray>()
967+
.expect("inner struct");
968+
let x = as_int64_array(inner.column_by_name("x").unwrap().as_ref())?;
969+
assert_eq!(x.value(0), 7);
970+
assert!(x.is_null(1));
971+
let y = inner.column_by_name("y").unwrap();
972+
let y = y
973+
.as_any()
974+
.downcast_ref::<BooleanArray>()
975+
.expect("boolean array");
976+
assert!(y.is_null(0));
977+
assert!(y.is_null(1));
978+
Ok(())
979+
}
980+
981+
#[test]
982+
// Confirms struct casting works for scalars by casting through array form and back to ScalarValue.
983+
fn cast_struct_scalar() -> Result<()> {
984+
let source_field = Field::new("a", Int32, true);
985+
let input_field = Field::new(
986+
"s",
987+
Struct(vec![Arc::new(source_field.clone())].into()),
988+
true,
989+
);
990+
let target_field = Field::new(
991+
"s",
992+
Struct(vec![Arc::new(Field::new("a", UInt8, true))].into()),
993+
true,
994+
);
995+
996+
let schema = make_schema(&input_field);
997+
let scalar_struct = StructArray::new(
998+
vec![Arc::new(source_field.clone())].into(),
999+
vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef],
1000+
None,
1001+
);
1002+
let literal =
1003+
Arc::new(Literal::new(ScalarValue::Struct(Arc::new(scalar_struct))));
1004+
let expr = CastExpr::new(
1005+
literal,
1006+
target_field.data_type().clone(),
1007+
Some(DEFAULT_CAST_OPTIONS),
1008+
);
1009+
1010+
let batch = RecordBatch::new_empty(Arc::clone(&schema));
1011+
let result = expr.evaluate(&batch)?;
1012+
let ColumnarValue::Scalar(ScalarValue::Struct(array)) = result else {
1013+
panic!("expected struct scalar");
1014+
};
1015+
let casted = array.column_by_name("a").unwrap();
1016+
let casted = as_uint8_array(casted.as_ref())?;
1017+
assert_eq!(casted.value(0), 9);
1018+
Ok(())
1019+
}
8121020
}

0 commit comments

Comments (0)