Skip to content

Commit 7b32f01

Browse files
committed
feat: Update Cast and TryCast to use FieldRef for result metadata
1 parent 2f5a531 commit 7b32f01

File tree

2 files changed

+67
-10
lines changed

2 files changed

+67
-10
lines changed

datafusion/expr/src/expr.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,7 @@ pub enum GetFieldAccess {
794794
pub struct Cast {
795795
/// The expression being cast
796796
pub expr: Box<Expr>,
797-
/// The `DataType` the expression will yield
798-
// pub data_type: DataType,
797+
/// Field describing the result of the cast, including metadata
799798
pub field: FieldRef,
800799
}
801800

@@ -811,7 +810,7 @@ impl Cast {
811810
pub struct TryCast {
812811
/// The expression being cast
813812
pub expr: Box<Expr>,
814-
/// The `DataType` the expression will yield
813+
/// Field describing the result of the cast, including metadata
815814
pub field: FieldRef,
816815
}
817816

datafusion/expr/src/expr_schema.rs

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -580,12 +580,13 @@ impl ExprSchemable for Expr {
580580
func.return_field_from_args(args)
581581
}
582582
// _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
583-
Expr::Cast(Cast { expr, field }) => expr
584-
.to_field(schema)
585-
.map(|(_, f)| {
586-
f.as_ref().clone().with_data_type(field.data_type().clone())
587-
})
588-
.map(Arc::new),
583+
Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr, field }) => {
584+
let (_, input_field) = expr.to_field(schema)?;
585+
let mut combined_metadata = FieldMetadata::from(input_field.metadata());
586+
combined_metadata.extend(FieldMetadata::from(field.metadata()));
587+
let field = combined_metadata.add_to_field(field.as_ref().clone());
588+
Ok(Arc::new(field))
589+
}
589590
Expr::Placeholder(Placeholder {
590591
id: _,
591592
field: Some(field),
@@ -595,7 +596,6 @@ impl ExprSchemable for Expr {
595596
| Expr::Not(_)
596597
| Expr::Between(_)
597598
| Expr::Case(_)
598-
| Expr::TryCast(_)
599599
| Expr::InList(_)
600600
| Expr::InSubquery(_)
601601
| Expr::Wildcard { .. }
@@ -782,6 +782,7 @@ mod tests {
782782
use std::collections::HashMap;
783783

784784
use super::*;
785+
use crate::expr::{Cast, TryCast};
785786
use crate::{col, lit, out_ref_col_with_metadata};
786787

787788
use datafusion_common::{internal_err, DFSchema, ScalarValue};
@@ -972,6 +973,63 @@ mod tests {
972973
);
973974
}
974975

976+
#[test]
977+
fn test_cast_metadata_overrides() {
978+
let source_meta = FieldMetadata::from(HashMap::from([
979+
("source".to_string(), "value".to_string()),
980+
("shared".to_string(), "source".to_string()),
981+
]));
982+
let cast_meta = FieldMetadata::from(HashMap::from([
983+
("shared".to_string(), "cast".to_string()),
984+
("cast".to_string(), "value".to_string()),
985+
]));
986+
987+
let schema = MockExprSchema::new()
988+
.with_data_type(DataType::Int32)
989+
.with_metadata(source_meta.clone());
990+
991+
let cast_field = Arc::new(
992+
Field::new("ignored", DataType::Utf8, true)
993+
.with_metadata(cast_meta.to_hashmap()),
994+
);
995+
996+
let expr = col("foo");
997+
let cast_expr = Expr::Cast(Cast::new(Box::new(expr), Arc::clone(&cast_field)));
998+
999+
let mut expected = source_meta.clone();
1000+
expected.extend(cast_meta.clone());
1001+
assert_eq!(expected, cast_expr.metadata(&schema).unwrap());
1002+
}
1003+
1004+
#[test]
1005+
fn test_try_cast_metadata_overrides() {
1006+
let source_meta = FieldMetadata::from(HashMap::from([
1007+
("source".to_string(), "value".to_string()),
1008+
("shared".to_string(), "source".to_string()),
1009+
]));
1010+
let cast_meta = FieldMetadata::from(HashMap::from([
1011+
("shared".to_string(), "cast".to_string()),
1012+
("cast".to_string(), "value".to_string()),
1013+
]));
1014+
1015+
let schema = MockExprSchema::new()
1016+
.with_data_type(DataType::Int32)
1017+
.with_metadata(source_meta.clone());
1018+
1019+
let cast_field = Arc::new(
1020+
Field::new("ignored", DataType::Utf8, true)
1021+
.with_metadata(cast_meta.to_hashmap()),
1022+
);
1023+
1024+
let expr = col("foo");
1025+
let cast_expr =
1026+
Expr::TryCast(TryCast::new(Box::new(expr), Arc::clone(&cast_field)));
1027+
1028+
let mut expected = source_meta.clone();
1029+
expected.extend(cast_meta.clone());
1030+
assert_eq!(expected, cast_expr.metadata(&schema).unwrap());
1031+
}
1032+
9751033
#[derive(Debug)]
9761034
struct MockExprSchema {
9771035
field: Field,

0 commit comments

Comments
 (0)