Skip to content

Commit 15b18c1

Browse files
authored
arrow-cast: Bring back in-order field casting for StructArray (#9007)
# Which issue does this PR close?

Closes #9005

# Rationale for this change

Avoid breaking existing behavior in a patch release.

# What changes are included in this PR?

Bring back in-order casting for structs that have the same number of fields.

# Are these changes tested?

Yes, the tests that were modified in #8871 were reverted to their previous form.

# Are there any user-facing changes?

Yes, it brings back the previous functionality.
1 parent 116ae12 commit 15b18c1

File tree

1 file changed

+93
-54
lines changed

1 file changed

+93
-54
lines changed

arrow-cast/src/cast/mod.rs

Lines changed: 93 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
254254
}
255255

256256
// slow path, we match the fields by name
257-
to_fields.iter().all(|to_field| {
257+
if to_fields.iter().all(|to_field| {
258258
from_fields
259259
.iter()
260260
.find(|from_field| from_field.name() == to_field.name())
@@ -263,7 +263,15 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
263263
// cast kernel will return error.
264264
can_cast_types(from_field.data_type(), to_field.data_type())
265265
})
266-
})
266+
}) {
267+
return true;
268+
}
269+
270+
// if we couldn't match by name, we try to see if they can be matched by position
271+
from_fields
272+
.iter()
273+
.zip(to_fields.iter())
274+
.all(|(f1, f2)| can_cast_types(f1.data_type(), f2.data_type()))
267275
}
268276
(Struct(_), _) => false,
269277
(_, Struct(_)) => false,
@@ -1218,49 +1226,12 @@ pub fn cast_with_options(
12181226
cast_options,
12191227
)
12201228
}
1221-
(Struct(from_fields), Struct(to_fields)) => {
1222-
let array = array.as_struct();
1223-
1224-
// Fast path: if field names are in the same order, we can just zip and cast
1225-
let fields_match_order = from_fields.len() == to_fields.len()
1226-
&& from_fields
1227-
.iter()
1228-
.zip(to_fields.iter())
1229-
.all(|(f1, f2)| f1.name() == f2.name());
1230-
1231-
let fields = if fields_match_order {
1232-
// Fast path: cast columns in order
1233-
array
1234-
.columns()
1235-
.iter()
1236-
.zip(to_fields.iter())
1237-
.map(|(column, field)| {
1238-
cast_with_options(column, field.data_type(), cast_options)
1239-
})
1240-
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1241-
} else {
1242-
// Slow path: match fields by name and reorder
1243-
to_fields
1244-
.iter()
1245-
.map(|to_field| {
1246-
let from_field_idx = from_fields
1247-
.iter()
1248-
.position(|from_field| from_field.name() == to_field.name())
1249-
.ok_or_else(|| {
1250-
ArrowError::CastError(format!(
1251-
"Field '{}' not found in source struct",
1252-
to_field.name()
1253-
))
1254-
})?;
1255-
let column = array.column(from_field_idx);
1256-
cast_with_options(column, to_field.data_type(), cast_options)
1257-
})
1258-
.collect::<Result<Vec<ArrayRef>, ArrowError>>()?
1259-
};
1260-
1261-
let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
1262-
Ok(Arc::new(array) as ArrayRef)
1263-
}
1229+
(Struct(from_fields), Struct(to_fields)) => cast_struct_to_struct(
1230+
array.as_struct(),
1231+
from_fields.clone(),
1232+
to_fields.clone(),
1233+
cast_options,
1234+
),
12641235
(Struct(_), _) => Err(ArrowError::CastError(format!(
12651236
"Casting from {from_type} to {to_type} not supported"
12661237
))),
@@ -2292,6 +2263,74 @@ pub fn cast_with_options(
22922263
}
22932264
}
22942265

2266+
fn cast_struct_to_struct(
2267+
array: &StructArray,
2268+
from_fields: Fields,
2269+
to_fields: Fields,
2270+
cast_options: &CastOptions,
2271+
) -> Result<ArrayRef, ArrowError> {
2272+
// Fast path: if field names are in the same order, we can just zip and cast
2273+
let fields_match_order = from_fields.len() == to_fields.len()
2274+
&& from_fields
2275+
.iter()
2276+
.zip(to_fields.iter())
2277+
.all(|(f1, f2)| f1.name() == f2.name());
2278+
2279+
let fields = if fields_match_order {
2280+
// Fast path: cast columns in order if their names match
2281+
cast_struct_fields_in_order(array, to_fields.clone(), cast_options)?
2282+
} else {
2283+
let all_fields_match_by_name = to_fields.iter().all(|to_field| {
2284+
from_fields
2285+
.iter()
2286+
.any(|from_field| from_field.name() == to_field.name())
2287+
});
2288+
2289+
if all_fields_match_by_name {
2290+
// Slow path: match fields by name and reorder
2291+
cast_struct_fields_by_name(array, from_fields.clone(), to_fields.clone(), cast_options)?
2292+
} else {
2293+
// Fallback: cast field by field in order
2294+
cast_struct_fields_in_order(array, to_fields.clone(), cast_options)?
2295+
}
2296+
};
2297+
2298+
let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?;
2299+
Ok(Arc::new(array) as ArrayRef)
2300+
}
2301+
2302+
fn cast_struct_fields_by_name(
2303+
array: &StructArray,
2304+
from_fields: Fields,
2305+
to_fields: Fields,
2306+
cast_options: &CastOptions,
2307+
) -> Result<Vec<ArrayRef>, ArrowError> {
2308+
to_fields
2309+
.iter()
2310+
.map(|to_field| {
2311+
let from_field_idx = from_fields
2312+
.iter()
2313+
.position(|from_field| from_field.name() == to_field.name())
2314+
.unwrap(); // safe because we checked above
2315+
let column = array.column(from_field_idx);
2316+
cast_with_options(column, to_field.data_type(), cast_options)
2317+
})
2318+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()
2319+
}
2320+
2321+
fn cast_struct_fields_in_order(
2322+
array: &StructArray,
2323+
to_fields: Fields,
2324+
cast_options: &CastOptions,
2325+
) -> Result<Vec<ArrayRef>, ArrowError> {
2326+
array
2327+
.columns()
2328+
.iter()
2329+
.zip(to_fields.iter())
2330+
.map(|(l, field)| cast_with_options(l, field.data_type(), cast_options))
2331+
.collect::<Result<Vec<ArrayRef>, ArrowError>>()
2332+
}
2333+
22952334
fn cast_from_decimal<D, F>(
22962335
array: &dyn Array,
22972336
base: D::Native,
@@ -10917,11 +10956,11 @@ mod tests {
1091710956
let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1091810957
let struct_array = StructArray::from(vec![
1091910958
(
10920-
Arc::new(Field::new("a", DataType::Boolean, false)),
10959+
Arc::new(Field::new("b", DataType::Boolean, false)),
1092110960
boolean.clone() as ArrayRef,
1092210961
),
1092310962
(
10924-
Arc::new(Field::new("b", DataType::Int32, false)),
10963+
Arc::new(Field::new("c", DataType::Int32, false)),
1092510964
int.clone() as ArrayRef,
1092610965
),
1092710966
]);
@@ -10965,11 +11004,11 @@ mod tests {
1096511004
let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None]));
1096611005
let struct_array = StructArray::from(vec![
1096711006
(
10968-
Arc::new(Field::new("a", DataType::Boolean, false)),
11007+
Arc::new(Field::new("b", DataType::Boolean, false)),
1096911008
boolean.clone() as ArrayRef,
1097011009
),
1097111010
(
10972-
Arc::new(Field::new("b", DataType::Int32, true)),
11011+
Arc::new(Field::new("c", DataType::Int32, true)),
1097311012
int.clone() as ArrayRef,
1097411013
),
1097511014
]);
@@ -10999,11 +11038,11 @@ mod tests {
1099911038
let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100]));
1100011039
let struct_array = StructArray::from(vec![
1100111040
(
11002-
Arc::new(Field::new("a", DataType::Boolean, false)),
11041+
Arc::new(Field::new("b", DataType::Boolean, false)),
1100311042
boolean.clone() as ArrayRef,
1100411043
),
1100511044
(
11006-
Arc::new(Field::new("b", DataType::Int32, false)),
11045+
Arc::new(Field::new("c", DataType::Int32, false)),
1100711046
int.clone() as ArrayRef,
1100811047
),
1100911048
]);
@@ -11139,7 +11178,7 @@ mod tests {
1113911178
assert!(result.is_err());
1114011179
assert_eq!(
1114111180
result.unwrap_err().to_string(),
11142-
"Cast error: Field 'b' not found in source struct"
11181+
"Invalid argument error: Incorrect number of arrays for StructArray fields, expected 2 got 1"
1114311182
);
1114411183
}
1114511184

@@ -11196,7 +11235,7 @@ mod tests {
1119611235
}
1119711236

1119811237
#[test]
11199-
fn test_can_cast_struct_with_missing_field() {
11238+
fn test_can_cast_struct_rename_field() {
1120011239
// Test that can_cast_types returns true when a target field name is not found in the source, because fields are then matched by position
1120111240
let from_type = DataType::Struct(
1120211241
vec![
@@ -11214,7 +11253,7 @@ mod tests {
1121411253
.into(),
1121511254
);
1121611255

11217-
assert!(!can_cast_types(&from_type, &to_type));
11256+
assert!(can_cast_types(&from_type, &to_type));
1121811257
}
1121911258

1122011259
fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) {

0 commit comments

Comments
 (0)