Skip to content

Commit daf45f7

Browse files
committed
Implement schema-aware CastColumnExpr constructor
Add a new CastColumnExpr::new_with_schema constructor that accepts and stores the input schema. Document the column-only helper for single-field validation paths. Update CastColumnExpr construction to include full input schemas during schema rewriting and proto parsing, ensuring correct type resolution.
1 parent 61cbaa7 commit daf45f7

File tree

7 files changed

+80
-24
lines changed

7 files changed

+80
-24
lines changed

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,12 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
466466
}
467467
}
468468

469-
let cast_expr = Arc::new(CastColumnExpr::new(
469+
let cast_expr = Arc::new(CastColumnExpr::new_with_schema(
470470
Arc::new(column),
471471
Arc::new(physical_field.clone()),
472472
Arc::new(logical_field.clone()),
473473
None,
474+
Arc::clone(&self.physical_file_schema),
474475
)?);
475476

476477
Ok(Transformed::yes(cast_expr))
@@ -659,8 +660,11 @@ mod tests {
659660
#[test]
660661
fn test_rewrite_multi_column_expr_with_type_cast() {
661662
let (physical_schema, logical_schema) = create_test_schema();
663+
let physical_schema = Arc::new(physical_schema);
664+
let logical_schema = Arc::new(logical_schema);
662665
let factory = DefaultPhysicalExprAdapterFactory;
663-
let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
666+
let adapter =
667+
factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));
664668

665669
// Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter
666670
let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
@@ -685,11 +689,12 @@ mod tests {
685689

686690
let expected = expressions::BinaryExpr::new(
687691
Arc::new(
688-
CastColumnExpr::new(
692+
CastColumnExpr::new_with_schema(
689693
Arc::new(Column::new("a", 0)),
690694
Arc::new(Field::new("a", DataType::Int32, false)),
691695
Arc::new(Field::new("a", DataType::Int64, false)),
692696
None,
697+
Arc::clone(&physical_schema),
693698
)
694699
.expect("cast column expr"),
695700
),
@@ -765,15 +770,18 @@ mod tests {
765770
false,
766771
)]);
767772

773+
let physical_schema = Arc::new(physical_schema);
774+
let logical_schema = Arc::new(logical_schema);
768775
let factory = DefaultPhysicalExprAdapterFactory;
769-
let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
776+
let adapter =
777+
factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));
770778
let column_expr = Arc::new(Column::new("data", 0));
771779

772780
let result = adapter.rewrite(column_expr).unwrap();
773781

774782
let expected = Arc::new(
775-
CastColumnExpr::new(
776-
Arc::new(Column::new("data", 0)),
783+
CastColumnExpr::new_with_schema(
784+
Arc::new(Column::new("data", 0)),
777785
Arc::new(Field::new(
778786
"data",
779787
DataType::Struct(
@@ -797,6 +805,7 @@ mod tests {
797805
false,
798806
)),
799807
None,
808+
Arc::clone(&physical_schema),
800809
)
801810
.expect("cast column expr"),
802811
) as Arc<dyn PhysicalExpr>;

datafusion/physical-expr/src/equivalence/properties/dependency.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,11 +516,12 @@ mod tests {
516516

517517
let input_field = Arc::new(input_schema.field(0).clone());
518518
let target_field = Arc::new(Field::new("a_cast", DataType::Int64, true));
519-
let cast_col = Arc::new(CastColumnExpr::new(
519+
let cast_col = Arc::new(CastColumnExpr::new_with_schema(
520520
Arc::clone(&col_a),
521521
input_field,
522522
target_field,
523523
None,
524+
Arc::clone(&input_schema),
524525
)?) as Arc<dyn PhysicalExpr>;
525526

526527
let proj_exprs = vec![

datafusion/physical-expr/src/expressions/cast_column.rs

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct CastColumnExpr {
5858
target_field: FieldRef,
5959
/// Options forwarded to [`cast_column`].
6060
cast_options: CastOptions<'static>,
61+
/// Schema used to resolve expression data types during construction.
62+
input_schema: Arc<Schema>,
6163
}
6264

6365
// Manually derive `PartialEq`/`Hash` as `Arc<dyn PhysicalExpr>` does not
@@ -82,14 +84,36 @@ impl Hash for CastColumnExpr {
8284

8385
impl CastColumnExpr {
8486
/// Create a new [`CastColumnExpr`].
87+
///
88+
/// This constructor assumes `expr` is a column expression and validates it
89+
/// against a single-field schema derived from `input_field`. If the
90+
/// expression depends on a broader schema (for example, computed
91+
/// expressions), use [`Self::new_with_schema`] instead.
8592
pub fn new(
8693
expr: Arc<dyn PhysicalExpr>,
8794
input_field: FieldRef,
8895
target_field: FieldRef,
8996
cast_options: Option<CastOptions<'static>>,
9097
) -> Result<Self> {
91-
let input_schema = Schema::new(vec![input_field.as_ref().clone()]);
92-
let expr_data_type = expr.data_type(&input_schema)?;
98+
let input_schema = Arc::new(Schema::new(vec![input_field.as_ref().clone()]));
99+
Self::new_with_schema(
100+
expr,
101+
input_field,
102+
target_field,
103+
cast_options,
104+
input_schema,
105+
)
106+
}
107+
108+
/// Create a new [`CastColumnExpr`] using the full input schema.
109+
pub fn new_with_schema(
110+
expr: Arc<dyn PhysicalExpr>,
111+
input_field: FieldRef,
112+
target_field: FieldRef,
113+
cast_options: Option<CastOptions<'static>>,
114+
input_schema: Arc<Schema>,
115+
) -> Result<Self> {
116+
let expr_data_type = expr.data_type(input_schema.as_ref())?;
93117
if input_field.data_type() != &expr_data_type {
94118
return plan_err!(
95119
"CastColumnExpr input field data type '{}' does not match expression data type '{}'",
@@ -125,6 +149,7 @@ impl CastColumnExpr {
125149
input_field,
126150
target_field,
127151
cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
152+
input_schema,
128153
})
129154
}
130155

@@ -208,11 +233,12 @@ impl PhysicalExpr for CastColumnExpr {
208233
) -> Result<Arc<dyn PhysicalExpr>> {
209234
assert_eq!(children.len(), 1);
210235
let child = children.pop().expect("CastColumnExpr child");
211-
Ok(Arc::new(Self::new(
236+
Ok(Arc::new(Self::new_with_schema(
212237
child,
213238
Arc::clone(&self.input_field),
214239
Arc::clone(&self.target_field),
215240
Some(self.cast_options.clone()),
241+
Arc::clone(&self.input_schema),
216242
)?))
217243
}
218244

@@ -253,11 +279,12 @@ mod tests {
253279
let batch = RecordBatch::try_new(Arc::clone(&schema), vec![values])?;
254280

255281
let column = Arc::new(Column::new_with_schema("a", schema.as_ref())?);
256-
let expr = CastColumnExpr::new(
282+
let expr = CastColumnExpr::new_with_schema(
257283
column,
258284
Arc::new(input_field.clone()),
259285
Arc::new(target_field.clone()),
260286
None,
287+
Arc::clone(&schema),
261288
)?;
262289

263290
let result = expr.evaluate(&batch)?;
@@ -307,11 +334,12 @@ mod tests {
307334
)?;
308335

309336
let column = Arc::new(Column::new_with_schema("s", schema.as_ref())?);
310-
let expr = CastColumnExpr::new(
337+
let expr = CastColumnExpr::new_with_schema(
311338
column,
312339
Arc::new(input_field.clone()),
313340
Arc::new(target_field.clone()),
314341
None,
342+
Arc::clone(&schema),
315343
)?;
316344

317345
let result = expr.evaluate(&batch)?;
@@ -377,11 +405,12 @@ mod tests {
377405
)?;
378406

379407
let column = Arc::new(Column::new_with_schema("root", schema.as_ref())?);
380-
let expr = CastColumnExpr::new(
408+
let expr = CastColumnExpr::new_with_schema(
381409
column,
382410
Arc::new(outer_field.clone()),
383411
Arc::new(target_field.clone()),
384412
None,
413+
Arc::clone(&schema),
385414
)?;
386415

387416
let result = expr.evaluate(&batch)?;
@@ -428,11 +457,12 @@ mod tests {
428457
);
429458
let literal =
430459
Arc::new(Literal::new(ScalarValue::Struct(Arc::new(scalar_struct))));
431-
let expr = CastColumnExpr::new(
460+
let expr = CastColumnExpr::new_with_schema(
432461
literal,
433462
Arc::new(input_field.clone()),
434463
Arc::new(target_field.clone()),
435464
None,
465+
Arc::clone(&schema),
436466
)?;
437467

438468
let batch = RecordBatch::new_empty(Arc::clone(&schema));

datafusion/physical-expr/src/intervals/utils.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,14 @@ mod tests {
214214

215215
let column_expr = col("a", &schema).unwrap();
216216
let cast_expr = Arc::new(
217-
CastColumnExpr::new(column_expr, input_field, target_field, None)
218-
.expect("cast column expr"),
217+
CastColumnExpr::new_with_schema(
218+
column_expr,
219+
input_field,
220+
target_field,
221+
None,
222+
Arc::clone(&schema),
223+
)
224+
.expect("cast column expr"),
219225
) as Arc<dyn PhysicalExpr>;
220226

221227
assert!(check_support(&cast_expr, &schema));

datafusion/physical-expr/src/simplifier/unwrap_cast.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,14 @@ mod tests {
222222
// Create: cast_column(c1 as INT64) > INT64(10)
223223
let column_expr = col("c1", &schema).unwrap();
224224
let cast_expr = Arc::new(
225-
CastColumnExpr::new(column_expr, input_field, target_field, None)
226-
.expect("cast column expr"),
225+
CastColumnExpr::new_with_schema(
226+
column_expr,
227+
input_field,
228+
target_field,
229+
None,
230+
Arc::new(schema.clone()),
231+
)
232+
.expect("cast column expr"),
227233
);
228234
let literal_expr = lit(10i64);
229235
let binary_expr =

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ pub fn parse_physical_expr(
364364
e.safe,
365365
e.format_options.as_ref(),
366366
)?;
367-
Arc::new(CastColumnExpr::new(
367+
Arc::new(CastColumnExpr::new_with_schema(
368368
parse_required_physical_expr(
369369
e.expr.as_deref(),
370370
ctx,
@@ -375,6 +375,7 @@ pub fn parse_physical_expr(
375375
Arc::new(Field::try_from(input_field)?),
376376
Arc::new(Field::try_from(target_field)?),
377377
cast_options,
378+
Arc::new(input_schema.clone()),
378379
)?)
379380
}
380381
ExprType::TryCast(e) => Arc::new(TryCastExpr::new(

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,19 +219,20 @@ fn roundtrip_cast_column_expr() -> Result<()> {
219219
safe: true,
220220
format_options,
221221
};
222-
let expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new(
222+
let input_schema = Schema::new(vec![input_field.clone()]);
223+
let expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new_with_schema(
223224
Arc::new(Column::new("a", 0)),
224225
Arc::new(input_field.clone()),
225226
Arc::new(target_field.clone()),
226227
Some(cast_options.clone()),
228+
Arc::new(input_schema.clone()),
227229
)?);
228230

229231
let ctx = SessionContext::new();
230232
let codec = DefaultPhysicalExtensionCodec {};
231233
let proto = datafusion_proto::physical_plan::to_proto::serialize_physical_expr(
232234
&expr, &codec,
233235
)?;
234-
let input_schema = Schema::new(vec![input_field.clone()]);
235236
let round_trip = datafusion_proto::physical_plan::from_proto::parse_physical_expr(
236237
&proto,
237238
&ctx.task_ctx(),
@@ -244,11 +245,12 @@ fn roundtrip_cast_column_expr() -> Result<()> {
244245
.downcast_ref::<CastColumnExpr>()
245246
.ok_or_else(|| internal_datafusion_err!("Expected CastColumnExpr"))?;
246247

247-
let expected = CastColumnExpr::new(
248+
let expected = CastColumnExpr::new_with_schema(
248249
Arc::new(Column::new("a", 0)),
249250
Arc::new(input_field.clone()),
250251
Arc::new(target_field.clone()),
251252
Some(cast_options),
253+
Arc::new(input_schema.clone()),
252254
)?;
253255

254256
assert_eq!(cast_expr, &expected);
@@ -274,19 +276,20 @@ fn roundtrip_cast_column_expr_with_target_field_change() -> Result<()> {
274276
let target_field =
275277
Field::new("payload_cast", DataType::Utf8, false).with_metadata(target_metadata);
276278

277-
let expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new(
279+
let input_schema = Schema::new(vec![input_field.clone()]);
280+
let expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new_with_schema(
278281
Arc::new(Column::new("payload", 0)),
279282
Arc::new(input_field.clone()),
280283
Arc::new(target_field.clone()),
281284
None,
285+
Arc::new(input_schema.clone()),
282286
)?);
283287

284288
let ctx = SessionContext::new();
285289
let codec = DefaultPhysicalExtensionCodec {};
286290
let proto = datafusion_proto::physical_plan::to_proto::serialize_physical_expr(
287291
&expr, &codec,
288292
)?;
289-
let input_schema = Schema::new(vec![input_field.clone()]);
290293
let round_trip = datafusion_proto::physical_plan::from_proto::parse_physical_expr(
291294
&proto,
292295
&ctx.task_ctx(),

0 commit comments

Comments
 (0)