Skip to content

Commit 61cbaa7

Browse files
committed
Implement fallible CastColumnExpr construction
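Make CastColumnExpr::new fallible: it now returns Result<Self> and validates its arguments up front. The constructor rejects an input field whose data type does not match the wrapped expression's data type, runs validate_struct_compatibility for struct-to-struct casts, rejects casting a non-struct input to a struct target, and for every other type pair requires Arrow's can_cast_types to succeed. Schema rewriting, with_new_children, and proto deserialization now propagate the constructor's error with `?`, and the affected tests either propagate it or unwrap it with expect.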
1 parent 0e3551c commit 61cbaa7

File tree

6 files changed

+93
-57
lines changed

6 files changed

+93
-57
lines changed

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
471471
Arc::new(physical_field.clone()),
472472
Arc::new(logical_field.clone()),
473473
None,
474-
));
474+
)?);
475475

476476
Ok(Transformed::yes(cast_expr))
477477
}
@@ -684,12 +684,15 @@ mod tests {
684684
println!("Rewritten expression: {result}");
685685

686686
let expected = expressions::BinaryExpr::new(
687-
Arc::new(CastColumnExpr::new(
688-
Arc::new(Column::new("a", 0)),
689-
Arc::new(Field::new("a", DataType::Int32, false)),
690-
Arc::new(Field::new("a", DataType::Int64, false)),
691-
None,
692-
)),
687+
Arc::new(
688+
CastColumnExpr::new(
689+
Arc::new(Column::new("a", 0)),
690+
Arc::new(Field::new("a", DataType::Int32, false)),
691+
Arc::new(Field::new("a", DataType::Int64, false)),
692+
None,
693+
)
694+
.expect("cast column expr"),
695+
),
693696
Operator::Plus,
694697
Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
695698
);
@@ -768,32 +771,35 @@ mod tests {
768771

769772
let result = adapter.rewrite(column_expr).unwrap();
770773

771-
let expected = Arc::new(CastColumnExpr::new(
772-
Arc::new(Column::new("data", 0)),
773-
Arc::new(Field::new(
774-
"data",
775-
DataType::Struct(
776-
vec![
777-
Field::new("id", DataType::Int32, false),
778-
Field::new("name", DataType::Utf8, true),
779-
]
780-
.into(),
781-
),
782-
false,
783-
)),
784-
Arc::new(Field::new(
785-
"data",
786-
DataType::Struct(
787-
vec![
788-
Field::new("id", DataType::Int64, false),
789-
Field::new("name", DataType::Utf8View, true),
790-
]
791-
.into(),
792-
),
793-
false,
794-
)),
795-
None,
796-
)) as Arc<dyn PhysicalExpr>;
774+
let expected = Arc::new(
775+
CastColumnExpr::new(
776+
Arc::new(Column::new("data", 0)),
777+
Arc::new(Field::new(
778+
"data",
779+
DataType::Struct(
780+
vec![
781+
Field::new("id", DataType::Int32, false),
782+
Field::new("name", DataType::Utf8, true),
783+
]
784+
.into(),
785+
),
786+
false,
787+
)),
788+
Arc::new(Field::new(
789+
"data",
790+
DataType::Struct(
791+
vec![
792+
Field::new("id", DataType::Int64, false),
793+
Field::new("name", DataType::Utf8View, true),
794+
]
795+
.into(),
796+
),
797+
false,
798+
)),
799+
None,
800+
)
801+
.expect("cast column expr"),
802+
) as Arc<dyn PhysicalExpr>;
797803

798804
assert_eq!(result.to_string(), expected.to_string());
799805
}

datafusion/physical-expr/src/equivalence/properties/dependency.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ mod tests {
521521
input_field,
522522
target_field,
523523
None,
524-
)) as Arc<dyn PhysicalExpr>;
524+
)?) as Arc<dyn PhysicalExpr>;
525525

526526
let proj_exprs = vec![
527527
(Arc::clone(&col_a), "a".to_string()),

datafusion/physical-expr/src/expressions/cast_column.rs

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
2020
use crate::physical_expr::PhysicalExpr;
2121
use arrow::{
22-
compute::CastOptions,
22+
compute::{CastOptions, can_cast_types},
2323
datatypes::{DataType, FieldRef, Schema},
2424
record_batch::RecordBatch,
2525
};
2626
use datafusion_common::{
27-
Result, ScalarValue, format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column,
27+
Result, ScalarValue, format::DEFAULT_CAST_OPTIONS,
28+
nested_struct::{cast_column, validate_struct_compatibility},
29+
plan_err,
2830
};
2931
use datafusion_expr_common::columnar_value::ColumnarValue;
3032
use std::{
@@ -85,13 +87,45 @@ impl CastColumnExpr {
8587
input_field: FieldRef,
8688
target_field: FieldRef,
8789
cast_options: Option<CastOptions<'static>>,
88-
) -> Self {
89-
Self {
90+
) -> Result<Self> {
91+
let input_schema = Schema::new(vec![input_field.as_ref().clone()]);
92+
let expr_data_type = expr.data_type(&input_schema)?;
93+
if input_field.data_type() != &expr_data_type {
94+
return plan_err!(
95+
"CastColumnExpr input field data type '{}' does not match expression data type '{}'",
96+
input_field.data_type(),
97+
expr_data_type
98+
);
99+
}
100+
101+
match (input_field.data_type(), target_field.data_type()) {
102+
(DataType::Struct(source_fields), DataType::Struct(target_fields)) => {
103+
validate_struct_compatibility(source_fields, target_fields)?;
104+
}
105+
(_, DataType::Struct(_)) => {
106+
return plan_err!(
107+
"CastColumnExpr cannot cast non-struct input '{}' to struct target '{}'",
108+
input_field.data_type(),
109+
target_field.data_type()
110+
);
111+
}
112+
_ => {
113+
if !can_cast_types(input_field.data_type(), target_field.data_type()) {
114+
return plan_err!(
115+
"CastColumnExpr cannot cast input type '{}' to target type '{}'",
116+
input_field.data_type(),
117+
target_field.data_type()
118+
);
119+
}
120+
}
121+
}
122+
123+
Ok(Self {
90124
expr,
91125
input_field,
92126
target_field,
93127
cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
94-
}
128+
})
95129
}
96130

97131
/// The expression that produces the value to be cast.
@@ -179,7 +213,7 @@ impl PhysicalExpr for CastColumnExpr {
179213
Arc::clone(&self.input_field),
180214
Arc::clone(&self.target_field),
181215
Some(self.cast_options.clone()),
182-
)))
216+
)?))
183217
}
184218

185219
fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -224,7 +258,7 @@ mod tests {
224258
Arc::new(input_field.clone()),
225259
Arc::new(target_field.clone()),
226260
None,
227-
);
261+
)?;
228262

229263
let result = expr.evaluate(&batch)?;
230264
let ColumnarValue::Array(array) = result else {
@@ -278,7 +312,7 @@ mod tests {
278312
Arc::new(input_field.clone()),
279313
Arc::new(target_field.clone()),
280314
None,
281-
);
315+
)?;
282316

283317
let result = expr.evaluate(&batch)?;
284318
let ColumnarValue::Array(array) = result else {
@@ -348,7 +382,7 @@ mod tests {
348382
Arc::new(outer_field.clone()),
349383
Arc::new(target_field.clone()),
350384
None,
351-
);
385+
)?;
352386

353387
let result = expr.evaluate(&batch)?;
354388
let ColumnarValue::Array(array) = result else {
@@ -399,7 +433,7 @@ mod tests {
399433
Arc::new(input_field.clone()),
400434
Arc::new(target_field.clone()),
401435
None,
402-
);
436+
)?;
403437

404438
let batch = RecordBatch::new_empty(Arc::clone(&schema));
405439
let result = expr.evaluate(&batch)?;

datafusion/physical-expr/src/intervals/utils.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,10 @@ mod tests {
213213
let target_field = Arc::new(Field::new("a", DataType::Int64, true));
214214

215215
let column_expr = col("a", &schema).unwrap();
216-
let cast_expr = Arc::new(CastColumnExpr::new(
217-
column_expr,
218-
input_field,
219-
target_field,
220-
None,
221-
)) as Arc<dyn PhysicalExpr>;
216+
let cast_expr = Arc::new(
217+
CastColumnExpr::new(column_expr, input_field, target_field, None)
218+
.expect("cast column expr"),
219+
) as Arc<dyn PhysicalExpr>;
222220

223221
assert!(check_support(&cast_expr, &schema));
224222
}

datafusion/physical-expr/src/simplifier/unwrap_cast.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,10 @@ mod tests {
221221

222222
// Create: cast_column(c1 as INT64) > INT64(10)
223223
let column_expr = col("c1", &schema).unwrap();
224-
let cast_expr = Arc::new(CastColumnExpr::new(
225-
column_expr,
226-
input_field,
227-
target_field,
228-
None,
229-
));
224+
let cast_expr = Arc::new(
225+
CastColumnExpr::new(column_expr, input_field, target_field, None)
226+
.expect("cast column expr"),
227+
);
230228
let literal_expr = lit(10i64);
231229
let binary_expr =
232230
Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ pub fn parse_physical_expr(
375375
Arc::new(Field::try_from(input_field)?),
376376
Arc::new(Field::try_from(target_field)?),
377377
cast_options,
378-
))
378+
)?)
379379
}
380380
ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
381381
parse_required_physical_expr(

0 commit comments

Comments
 (0)