Skip to content

Commit 25558c2

Browse files
committed
struct casting: remove positional fallback; validate missing non-nullable target fields; update tests
1 parent b7b6e10 commit 25558c2

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

datafusion/common/src/nested_struct.rs

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -66,27 +66,16 @@ fn cast_struct_column(
6666

6767
if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() {
6868
let source_fields = source_struct.fields();
69-
let has_overlap = fields_have_name_overlap(source_fields, target_fields);
7069
validate_struct_compatibility(source_fields, target_fields)?;
7170

7271
let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
7372
let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
7473
let num_rows = source_col.len();
7574

76-
// Iterate target fields and pick source child either by name (when fields overlap)
77-
// or by position (when there is no name overlap).
78-
for (index, target_child_field) in target_fields.iter().enumerate() {
75+
// Iterate target fields and pick source child by name. Missing fields are filled with NULLs.
76+
for target_child_field in target_fields.iter() {
7977
fields.push(Arc::clone(target_child_field));
80-
81-
// Determine the source child column: by name when overlapping names exist,
82-
// otherwise by position.
83-
let source_child_opt: Option<&ArrayRef> = if has_overlap {
84-
source_struct.column_by_name(target_child_field.name())
85-
} else {
86-
Some(source_struct.column(index))
87-
};
88-
89-
match source_child_opt {
78+
match source_struct.column_by_name(target_child_field.name()) {
9079
Some(source_child_col) => {
9180
let adapted_child =
9281
cast_column(source_child_col, target_child_field, cast_options)
@@ -262,22 +251,14 @@ pub fn validate_struct_compatibility(
262251
source_fields: &[FieldRef],
263252
target_fields: &[FieldRef],
264253
) -> Result<()> {
254+
// Require at least one overlapping field name between source and target
265255
let has_overlap = fields_have_name_overlap(source_fields, target_fields);
266256
if !has_overlap {
267-
if source_fields.len() != target_fields.len() {
268-
return _plan_err!(
269-
"Cannot cast struct with {} fields to {} fields without name overlap; positional mapping is ambiguous",
270-
source_fields.len(),
271-
target_fields.len()
272-
);
273-
}
274-
275-
for (source_field, target_field) in source_fields.iter().zip(target_fields.iter())
276-
{
277-
validate_field_compatibility(source_field, target_field)?;
278-
}
279-
280-
return Ok(());
257+
return _plan_err!(
258+
"Cannot cast struct: at least one field name must match between source and target. Source fields: {:?}, Target fields: {:?}",
259+
source_fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
260+
target_fields.iter().map(|f| f.name()).collect::<Vec<_>>()
261+
);
281262
}
282263

283264
// Check compatibility for each target field
@@ -288,8 +269,15 @@ pub fn validate_struct_compatibility(
288269
.find(|f| f.name() == target_field.name())
289270
{
290271
validate_field_compatibility(source_field, target_field)?;
272+
} else {
273+
// Target field missing from source: allowed only if target field is nullable
274+
if !target_field.is_nullable() {
275+
return _plan_err!(
276+
"Cannot cast struct: target field '{}' is non-nullable but missing from source",
277+
target_field.name()
278+
);
279+
}
291280
}
292-
// Missing fields in source are OK - they'll be filled with nulls
293281
}
294282

295283
// Extra fields in source are OK - they'll be ignored
@@ -572,7 +560,7 @@ mod tests {
572560
let result = validate_struct_compatibility(&source_fields, &target_fields);
573561
assert!(result.is_err());
574562
let error_msg = result.unwrap_err().to_string();
575-
assert!(error_msg.contains("positional mapping is ambiguous"));
563+
assert!(error_msg.contains("at least one field name must match"));
576564
}
577565

578566
#[test]
@@ -869,16 +857,26 @@ mod tests {
869857
vec![field("a", DataType::Int64), field("b", DataType::Utf8)],
870858
);
871859

872-
let result =
873-
cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
874-
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
860+
let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS);
861+
assert!(result.is_err());
862+
let msg = result.unwrap_err().to_string();
863+
assert!(msg.contains("at least one field name must match"));
864+
}
875865

876-
let a_col = get_column_as!(&struct_array, "a", Int64Array);
877-
assert_eq!(a_col.value(0), 10);
878-
assert_eq!(a_col.value(1), 20);
866+
#[test]
867+
fn test_validate_struct_compatibility_missing_non_nullable_target_field() {
868+
// Source has only 'a'
869+
let source_fields = vec![arc_field("a", DataType::Int32)];
879870

880-
let b_col = get_column_as!(&struct_array, "b", StringArray);
881-
assert_eq!(b_col.value(0), "alpha");
882-
assert_eq!(b_col.value(1), "beta");
871+
// Target requires non-nullable 'b'
872+
let target_fields = vec![
873+
arc_field("a", DataType::Int32),
874+
Arc::new(non_null_field("b", DataType::Int32)),
875+
];
876+
877+
let result = validate_struct_compatibility(&source_fields, &target_fields);
878+
assert!(result.is_err());
879+
let msg = result.unwrap_err().to_string();
880+
assert!(msg.contains("non-nullable but missing"));
883881
}
884882
}

0 commit comments

Comments
 (0)