|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +use crate::error::{DataFusionError, Result, _plan_err}; |
| 19 | +use arrow::{ |
| 20 | + array::{new_null_array, Array, ArrayRef, StructArray}, |
| 21 | + compute::cast, |
| 22 | + datatypes::{DataType::Struct, Field, FieldRef}, |
| 23 | +}; |
| 24 | +use std::sync::Arc; |
| 25 | + |
| 26 | +/// Cast a struct column to match target struct fields, handling nested structs recursively. |
| 27 | +/// |
| 28 | +/// This function implements struct-to-struct casting with the assumption that **structs should |
| 29 | +/// always be allowed to cast to other structs**. However, the source column must already be |
| 30 | +/// a struct type - non-struct sources will result in an error. |
| 31 | +/// |
| 32 | +/// ## Field Matching Strategy |
| 33 | +/// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) |
| 34 | +/// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type |
| 35 | +/// - **Missing Fields**: Target fields not present in the source are filled with null values |
| 36 | +/// - **Extra Fields**: Source fields not present in the target are ignored |
| 37 | +/// |
| 38 | +/// ## Nested Struct Handling |
| 39 | +/// - Nested structs are handled recursively using the same casting rules |
| 40 | +/// - Each level of nesting follows the same field matching and null-filling strategy |
| 41 | +/// - This allows for complex struct transformations while maintaining data integrity |
| 42 | +/// |
| 43 | +/// # Arguments |
| 44 | +/// * `source_col` - The source array to cast (must be a struct array) |
| 45 | +/// * `target_fields` - The target struct field definitions to cast to |
| 46 | +/// |
| 47 | +/// # Returns |
| 48 | +/// A `Result<ArrayRef>` containing the cast struct array |
| 49 | +/// |
| 50 | +/// # Errors |
| 51 | +/// Returns a `DataFusionError::Plan` if the source column is not a struct type |
| 52 | +fn cast_struct_column( |
| 53 | + source_col: &ArrayRef, |
| 54 | + target_fields: &[Arc<Field>], |
| 55 | +) -> Result<ArrayRef> { |
| 56 | + if let Some(struct_array) = source_col.as_any().downcast_ref::<StructArray>() { |
| 57 | + let mut children: Vec<(Arc<Field>, Arc<dyn Array>)> = Vec::new(); |
| 58 | + let num_rows = source_col.len(); |
| 59 | + |
| 60 | + for target_child_field in target_fields { |
| 61 | + let field_arc = Arc::clone(target_child_field); |
| 62 | + match struct_array.column_by_name(target_child_field.name()) { |
| 63 | + Some(source_child_col) => { |
| 64 | + let adapted_child = |
| 65 | + cast_column(source_child_col, target_child_field)?; |
| 66 | + children.push((field_arc, adapted_child)); |
| 67 | + } |
| 68 | + None => { |
| 69 | + children.push(( |
| 70 | + field_arc, |
| 71 | + new_null_array(target_child_field.data_type(), num_rows), |
| 72 | + )); |
| 73 | + } |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + let struct_array = StructArray::from(children); |
| 78 | + Ok(Arc::new(struct_array)) |
| 79 | + } else { |
| 80 | + // Return error if source is not a struct type |
| 81 | + Err(DataFusionError::Plan(format!( |
| 82 | + "Cannot cast column of type {:?} to struct type. Source must be a struct to cast to struct.", |
| 83 | + source_col.data_type() |
| 84 | + ))) |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +/// Cast a column to match the target field type, with special handling for nested structs. |
| 89 | +/// |
| 90 | +/// This function serves as the main entry point for column casting operations. For struct |
| 91 | +/// types, it enforces that **only struct columns can be cast to struct types**. |
| 92 | +/// |
| 93 | +/// ## Casting Behavior |
| 94 | +/// - **Struct Types**: Delegates to `cast_struct_column` for struct-to-struct casting only |
| 95 | +/// - **Non-Struct Types**: Uses Arrow's standard `cast` function for primitive type conversions |
| 96 | +/// |
| 97 | +/// ## Struct Casting Requirements |
| 98 | +/// The struct casting logic requires that the source column must already be a struct type. |
| 99 | +/// This makes the function useful for: |
| 100 | +/// - Schema evolution scenarios where struct layouts change over time |
| 101 | +/// - Data migration between different struct schemas |
| 102 | +/// - Type-safe data processing pipelines that maintain struct type integrity |
| 103 | +/// |
| 104 | +/// # Arguments |
| 105 | +/// * `source_col` - The source array to cast |
| 106 | +/// * `target_field` - The target field definition (including type and metadata) |
| 107 | +/// |
| 108 | +/// # Returns |
| 109 | +/// A `Result<ArrayRef>` containing the cast array |
| 110 | +/// |
| 111 | +/// # Errors |
| 112 | +/// Returns an error if: |
| 113 | +/// - Attempting to cast a non-struct column to a struct type |
| 114 | +/// - Arrow's cast function fails for non-struct types |
| 115 | +/// - Memory allocation fails during struct construction |
| 116 | +/// - Invalid data type combinations are encountered |
| 117 | +pub fn cast_column(source_col: &ArrayRef, target_field: &Field) -> Result<ArrayRef> { |
| 118 | + match target_field.data_type() { |
| 119 | + Struct(target_fields) => cast_struct_column(source_col, target_fields), |
| 120 | + _ => Ok(cast(source_col, target_field.data_type())?), |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +/// Validates compatibility between source and target struct fields for casting operations. |
| 125 | +/// |
| 126 | +/// This function implements comprehensive struct compatibility checking by examining: |
| 127 | +/// - Field name matching between source and target structs |
| 128 | +/// - Type castability for each matching field (including recursive struct validation) |
| 129 | +/// - Proper handling of missing fields (target fields not in source are allowed - filled with nulls) |
| 130 | +/// - Proper handling of extra fields (source fields not in target are allowed - ignored) |
| 131 | +/// |
| 132 | +/// # Compatibility Rules |
| 133 | +/// - **Field Matching**: Fields are matched by name (case-sensitive) |
| 134 | +/// - **Missing Target Fields**: Allowed - will be filled with null values during casting |
| 135 | +/// - **Extra Source Fields**: Allowed - will be ignored during casting |
| 136 | +/// - **Type Compatibility**: Each matching field must be castable using Arrow's type system |
| 137 | +/// - **Nested Structs**: Recursively validates nested struct compatibility |
| 138 | +/// |
| 139 | +/// # Arguments |
| 140 | +/// * `source_fields` - Fields from the source struct type |
| 141 | +/// * `target_fields` - Fields from the target struct type |
| 142 | +/// |
| 143 | +/// # Returns |
| 144 | +/// * `Ok(true)` if the structs are compatible for casting |
| 145 | +/// * `Err(DataFusionError)` with detailed error message if incompatible |
| 146 | +/// |
| 147 | +/// # Examples |
| 148 | +/// ```text |
| 149 | +/// // Compatible: source has extra field, target has missing field |
| 150 | +/// // Source: {a: i32, b: string, c: f64} |
| 151 | +/// // Target: {a: i64, d: bool} |
| 152 | +/// // Result: Ok(true) - 'a' can cast i32->i64, 'b','c' ignored, 'd' filled with nulls |
| 153 | +/// |
| 154 | +/// // Incompatible: matching field has incompatible types |
| 155 | +/// // Source: {a: string} |
| 156 | +/// // Target: {a: binary} |
| 157 | +/// // Result: Err(...) - string cannot cast to binary |
| 158 | +/// ``` |
| 159 | +pub fn validate_struct_compatibility( |
| 160 | + source_fields: &[FieldRef], |
| 161 | + target_fields: &[FieldRef], |
| 162 | +) -> Result<bool> { |
| 163 | + // Check compatibility for each target field |
| 164 | + for target_field in target_fields { |
| 165 | + // Look for matching field in source by name |
| 166 | + if let Some(source_field) = source_fields |
| 167 | + .iter() |
| 168 | + .find(|f| f.name() == target_field.name()) |
| 169 | + { |
| 170 | + // Check if the matching field types are compatible |
| 171 | + match (source_field.data_type(), target_field.data_type()) { |
| 172 | + // Recursively validate nested structs |
| 173 | + (Struct(source_nested), Struct(target_nested)) => { |
| 174 | + validate_struct_compatibility(source_nested, target_nested)?; |
| 175 | + } |
| 176 | + // For non-struct types, use the existing castability check |
| 177 | + _ => { |
| 178 | + if !arrow::compute::can_cast_types( |
| 179 | + source_field.data_type(), |
| 180 | + target_field.data_type(), |
| 181 | + ) { |
| 182 | + return _plan_err!( |
| 183 | + "Cannot cast struct field '{}' from type {:?} to type {:?}", |
| 184 | + target_field.name(), |
| 185 | + source_field.data_type(), |
| 186 | + target_field.data_type() |
| 187 | + ); |
| 188 | + } |
| 189 | + } |
| 190 | + } |
| 191 | + } |
| 192 | + // Missing fields in source are OK - they'll be filled with nulls |
| 193 | + } |
| 194 | + |
| 195 | + // Extra fields in source are OK - they'll be ignored |
| 196 | + Ok(true) |
| 197 | +} |
| 198 | + |
| 199 | +#[cfg(test)] |
| 200 | +mod tests { |
| 201 | + use super::*; |
| 202 | + use arrow::{ |
| 203 | + array::{Int32Array, Int64Array, StringArray}, |
| 204 | + datatypes::{DataType, Field}, |
| 205 | + }; |
| 206 | + /// Macro to extract and downcast a column from a StructArray |
| 207 | + macro_rules! get_column_as { |
| 208 | + ($struct_array:expr, $column_name:expr, $array_type:ty) => { |
| 209 | + $struct_array |
| 210 | + .column_by_name($column_name) |
| 211 | + .unwrap() |
| 212 | + .as_any() |
| 213 | + .downcast_ref::<$array_type>() |
| 214 | + .unwrap() |
| 215 | + }; |
| 216 | + } |
| 217 | + |
| 218 | + #[test] |
| 219 | + fn test_cast_simple_column() { |
| 220 | + let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; |
| 221 | + let target_field = Field::new("ints", DataType::Int64, true); |
| 222 | + let result = cast_column(&source, &target_field).unwrap(); |
| 223 | + let result = result.as_any().downcast_ref::<Int64Array>().unwrap(); |
| 224 | + assert_eq!(result.len(), 3); |
| 225 | + assert_eq!(result.value(0), 1); |
| 226 | + assert_eq!(result.value(1), 2); |
| 227 | + assert_eq!(result.value(2), 3); |
| 228 | + } |
| 229 | + |
| 230 | + #[test] |
| 231 | + fn test_cast_struct_with_missing_field() { |
| 232 | + let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef; |
| 233 | + let source_struct = StructArray::from(vec![( |
| 234 | + Arc::new(Field::new("a", DataType::Int32, true)), |
| 235 | + Arc::clone(&a_array), |
| 236 | + )]); |
| 237 | + let source_col = Arc::new(source_struct) as ArrayRef; |
| 238 | + |
| 239 | + let target_field = Field::new( |
| 240 | + "s", |
| 241 | + Struct( |
| 242 | + vec![ |
| 243 | + Arc::new(Field::new("a", DataType::Int32, true)), |
| 244 | + Arc::new(Field::new("b", DataType::Utf8, true)), |
| 245 | + ] |
| 246 | + .into(), |
| 247 | + ), |
| 248 | + true, |
| 249 | + ); |
| 250 | + |
| 251 | + let result = cast_column(&source_col, &target_field).unwrap(); |
| 252 | + let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap(); |
| 253 | + assert_eq!(struct_array.fields().len(), 2); |
| 254 | + let a_result = get_column_as!(&struct_array, "a", Int32Array); |
| 255 | + assert_eq!(a_result.value(0), 1); |
| 256 | + assert_eq!(a_result.value(1), 2); |
| 257 | + |
| 258 | + let b_result = get_column_as!(&struct_array, "b", StringArray); |
| 259 | + assert_eq!(b_result.len(), 2); |
| 260 | + assert!(b_result.is_null(0)); |
| 261 | + assert!(b_result.is_null(1)); |
| 262 | + } |
| 263 | + |
| 264 | + #[test] |
| 265 | + fn test_cast_struct_source_not_struct() { |
| 266 | + let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; |
| 267 | + let target_field = Field::new( |
| 268 | + "s", |
| 269 | + Struct(vec![Arc::new(Field::new("a", DataType::Int32, true))].into()), |
| 270 | + true, |
| 271 | + ); |
| 272 | + |
| 273 | + let result = cast_column(&source, &target_field); |
| 274 | + assert!(result.is_err()); |
| 275 | + let error_msg = result.unwrap_err().to_string(); |
| 276 | + assert!(error_msg.contains("Cannot cast column of type")); |
| 277 | + assert!(error_msg.contains("to struct type")); |
| 278 | + assert!(error_msg.contains("Source must be a struct")); |
| 279 | + } |
| 280 | + |
| 281 | + #[test] |
| 282 | + fn test_validate_struct_compatibility_incompatible_types() { |
| 283 | + // Source struct: {field1: Binary, field2: String} |
| 284 | + let source_fields = vec![ |
| 285 | + Arc::new(Field::new("field1", DataType::Binary, true)), |
| 286 | + Arc::new(Field::new("field2", DataType::Utf8, true)), |
| 287 | + ]; |
| 288 | + |
| 289 | + // Target struct: {field1: Int32} |
| 290 | + let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))]; |
| 291 | + |
| 292 | + let result = validate_struct_compatibility(&source_fields, &target_fields); |
| 293 | + assert!(result.is_err()); |
| 294 | + let error_msg = result.unwrap_err().to_string(); |
| 295 | + assert!(error_msg.contains("Cannot cast struct field 'field1'")); |
| 296 | + assert!(error_msg.contains("Binary")); |
| 297 | + assert!(error_msg.contains("Int32")); |
| 298 | + } |
| 299 | + |
| 300 | + #[test] |
| 301 | + fn test_validate_struct_compatibility_compatible_types() { |
| 302 | + // Source struct: {field1: Int32, field2: String} |
| 303 | + let source_fields = vec![ |
| 304 | + Arc::new(Field::new("field1", DataType::Int32, true)), |
| 305 | + Arc::new(Field::new("field2", DataType::Utf8, true)), |
| 306 | + ]; |
| 307 | + |
| 308 | + // Target struct: {field1: Int64} (Int32 can cast to Int64) |
| 309 | + let target_fields = vec![Arc::new(Field::new("field1", DataType::Int64, true))]; |
| 310 | + |
| 311 | + let result = validate_struct_compatibility(&source_fields, &target_fields); |
| 312 | + assert!(result.is_ok()); |
| 313 | + assert!(result.unwrap()); |
| 314 | + } |
| 315 | + |
| 316 | + #[test] |
| 317 | + fn test_validate_struct_compatibility_missing_field_in_source() { |
| 318 | + // Source struct: {field2: String} (missing field1) |
| 319 | + let source_fields = vec![Arc::new(Field::new("field2", DataType::Utf8, true))]; |
| 320 | + |
| 321 | + // Target struct: {field1: Int32} |
| 322 | + let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))]; |
| 323 | + |
| 324 | + // Should be OK - missing fields will be filled with nulls |
| 325 | + let result = validate_struct_compatibility(&source_fields, &target_fields); |
| 326 | + assert!(result.is_ok()); |
| 327 | + assert!(result.unwrap()); |
| 328 | + } |
| 329 | +} |
0 commit comments