|
17 | 17 |
|
18 | 18 | //! Physical expression for struct-aware casting of columns. |
19 | 19 |
|
20 | | -use crate::physical_expr::PhysicalExpr; |
| 20 | +use crate::{expressions::Column, physical_expr::PhysicalExpr}; |
21 | 21 | use arrow::{ |
22 | 22 | compute::{CastOptions, can_cast_types}, |
23 | 23 | datatypes::{DataType, FieldRef, Schema}, |
@@ -112,6 +112,21 @@ impl CastColumnExpr { |
112 | 112 | expr_data_type |
113 | 113 | ); |
114 | 114 | } |
| 115 | + if let Some(column) = expr.as_any().downcast_ref::<Column>() { |
| 116 | + let schema_field = input_schema.field(column.index()); |
| 117 | + if schema_field.name() != input_field.name() |
| 118 | + || schema_field.data_type() != input_field.data_type() |
| 119 | + { |
| 120 | + return plan_err!( |
| 121 | + "CastColumnExpr input field '{}' (type '{}') does not match schema field '{}' (type '{}') at index {}", |
| 122 | + input_field.name(), |
| 123 | + input_field.data_type(), |
| 124 | + schema_field.name(), |
| 125 | + schema_field.data_type(), |
| 126 | + column.index() |
| 127 | + ); |
| 128 | + } |
| 129 | + } |
115 | 130 |
|
116 | 131 | match (input_field.data_type(), target_field.data_type()) { |
117 | 132 | (DataType::Struct(source_fields), DataType::Struct(target_fields)) => { |
@@ -509,4 +524,28 @@ mod tests { |
509 | 524 | assert_eq!(casted.value(0), 9); |
510 | 525 | Ok(()) |
511 | 526 | } |
| 527 | + |
| 528 | + #[test] |
| 529 | + fn cast_column_schema_mismatch() { |
| 530 | + let input_field = Field::new("a", DataType::Int32, true); |
| 531 | + let target_field = Field::new("a", DataType::Int32, true); |
| 532 | + let schema = Arc::new(Schema::new(vec![ |
| 533 | + input_field.clone(), |
| 534 | + Field::new("b", DataType::Int32, true), |
| 535 | + ])); |
| 536 | + |
| 537 | + let column = Arc::new(Column::new("b", 1)); |
| 538 | + let err = CastColumnExpr::new_with_schema( |
| 539 | + column, |
| 540 | + Arc::new(input_field), |
| 541 | + Arc::new(target_field), |
| 542 | + None, |
| 543 | + schema, |
| 544 | + ) |
| 545 | + .expect_err("expected mismatched input field error"); |
| 546 | + |
| 547 | + assert!(err |
| 548 | + .to_string() |
| 549 | + .contains("does not match schema field")); |
| 550 | + } |
512 | 551 | } |
0 commit comments