From f5dd7f5cb8dff976935fcfa04821bb33903a88f1 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Oct 2025 12:24:24 -0500 Subject: [PATCH 01/18] passing tests --- .../user_defined_scalar_functions.rs | 5 +- datafusion/expr/src/expr.rs | 52 ++++++- datafusion/expr/src/expr_schema.rs | 13 +- datafusion/expr/src/tree_node.rs | 4 +- datafusion/functions/src/core/arrow_cast.rs | 2 +- datafusion/physical-expr/src/planner.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 +- datafusion/sql/src/unparser/expr.rs | 138 +++++++----------- .../src/logical_plan/producer/expr/cast.rs | 12 +- .../src/logical_plan/producer/types.rs | 17 ++- 10 files changed, 140 insertions(+), 113 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 1361091a4cb5a..a6f7e89eeec9e 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -713,10 +713,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: DataType::Int64, - }) + Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), DataType::Int64)) }; // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 714925dadc7ab..17e9de5e25e45 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -796,13 +796,23 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub data_type: FieldRef, } impl Cast { /// Create a new Cast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + data_type: Field::new("", data_type, true).into(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { + expr, + data_type: field, + } } } @@ -812,13 +822,23 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub data_type: FieldRef, } impl TryCast { /// Create a new TryCast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + data_type: Field::new("", data_type, true).into(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { + expr, + data_type: field, + } } } @@ -3284,10 +3304,28 @@ impl Display for Expr { write!(f, "END") } Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr} AS {data_type})") + if data_type.metadata().is_empty() { + write!(f, "CAST({expr} AS {})", data_type.data_type()) + } else { + write!( + f, + "CAST({expr} AS {}<{:?}>)", + data_type.data_type(), + data_type.metadata() + ) + } } Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type})") + if data_type.metadata().is_empty() { + write!(f, "TRY_CAST({expr} AS {})", data_type.data_type()) + } else { + write!( + f, + "TRY_CAST({expr} AS {}<{:?}>)", + data_type.data_type(), + data_type.metadata() + ) + } } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), @@ -3673,7 +3711,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), 
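// A hedged usage sketch (illustrative, not part of the diff itself) of how the two
// constructors added in this patch differ: `Cast::new` keeps the old `DataType`-based
// signature by wrapping the type in an anonymous nullable `Field`, while the new
// `Cast::new_from_field` (and `TryCast::new_from_field`) accepts a `FieldRef`, so
// nullability and extension metadata can travel with the cast target. Variable names
// below are illustrative only.
//
//     use std::sync::Arc;
//     use arrow::datatypes::{DataType, Field};
//     use datafusion_expr::{lit, Cast, Expr};
//
//     // Same shape as the pre-patch API: a bare target type, assumed nullable.
//     let plain = Expr::Cast(Cast::new(Box::new(lit(1.23_f32)), DataType::Utf8));
//
//     // New in this patch: the target is a full field, so Arrow extension metadata
//     // (here the `arrow.uuid` name also used by the tests later in this series)
//     // rides along with the cast.
//     let uuid_target = Arc::new(
//         Field::new("", DataType::FixedSizeBinary(16), true).with_metadata(
//             [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())].into(),
//         ),
//     );
//     let tagged = Expr::Cast(Cast::new_from_field(
//         Box::new(lit("3230e5d4-888e-408b-b09b-831f44aa0c58")),
//         uuid_target,
//     ));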
None)), - data_type: DataType::Utf8, + data_type: Field::new("", DataType::Utf8, true).into(), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 94f5b0480b651..3f07d2e0d71e9 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -130,7 +130,9 @@ impl ExprSchemable for Expr { .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), + | Expr::TryCast(TryCast { data_type, .. }) => { + Ok(data_type.data_type().clone()) + } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list @@ -633,7 +635,14 @@ impl ExprSchemable for Expr { // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), Expr::Cast(Cast { expr, data_type }) => expr .to_field(schema) - .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) + .map(|(_, f)| { + f.as_ref() + .clone() + .with_data_type(data_type.data_type().clone()) + .with_metadata(f.metadata().clone()) + // TODO: should nullability be overridden here or derived from the + // input expression? + }) .map(Arc::new), Expr::Placeholder(Placeholder { id: _, diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f80608..e949bd71a71f5 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -222,10 +222,10 @@ impl TreeNode for Expr { }), Expr::Cast(Cast { expr, data_type }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), + .update_data(|be| Expr::Cast(Cast::new_from_field(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => expr .map_elements(f)? 
- .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, data_type))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index c4e58601cd106..09d9f483e4db7 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -171,7 +171,7 @@ impl ScalarUDFImpl for ArrowCastFunc { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { expr: Box::new(arg), - data_type: target_type, + data_type: Field::new("", target_type, true).into(), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7790380dffd56..61b85724ff454 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -290,12 +290,14 @@ pub fn create_physical_expr( Expr::Cast(Cast { expr, data_type }) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + // TODO: this drops extension metadata associated with the cast + data_type.data_type().clone(), ), Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + // TODO: this drops extension metadata associated with the cast + data_type.data_type().clone(), ), Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..346d45c05187e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -525,7 +525,7 @@ pub fn serialize_expr( Expr::Cast(Cast { expr, data_type }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(data_type.data_type().try_into()?), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), @@ -534,7 +534,7 @@ pub fn serialize_expr( Expr::TryCast(TryCast { expr, data_type }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(data_type.data_type().try_into()?), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 10c34d5a4df7b..f7d0134ff696c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -37,6 +37,7 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, + Field, FieldRef, }; use arrow::util::display::array_value_to_string; use datafusion_common::{ @@ -1133,24 +1134,27 @@ impl Unparser<'_> { // Explicit type cast on ast::Expr::Value is not needed by underlying engine for certain types // For example: CAST(Utf8("binary_value") AS Binary) and CAST(Utf8("dictionary_value") AS Dictionary) - fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) -> Result { + fn cast_to_sql(&self, expr: &Expr, field: &FieldRef) -> Result 
{ let inner_expr = self.expr_to_sql_inner(expr)?; + let data_type = field.data_type(); match inner_expr { ast::Expr::Value(_) => match data_type { - DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView => { + DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView + if field.metadata().is_empty() => + { Ok(inner_expr) } _ => Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }), }, _ => Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }), } @@ -1672,7 +1676,8 @@ impl Unparser<'_> { })) } - fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { + fn arrow_dtype_to_ast_dtype(&self, field: &FieldRef) -> Result { + let data_type = field.data_type(); match data_type { DataType::Null => { not_impl_err!("Unsupported DataType: conversion: {data_type}") @@ -1745,7 +1750,9 @@ impl Unparser<'_> { DataType::Union(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), + DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype( + &Field::new("", val.as_ref().clone(), true).into(), + ), DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) | DataType::Decimal128(precision, scale) @@ -1885,34 +1892,25 @@ mod tests { r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Date64, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)), r#"CAST(a AS DATETIME)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+08:00".into()), - ), - }), + Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())), + )), r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - }), + Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Millisecond, None), + )), r#"CAST(a AS TIMESTAMP)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::UInt32, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::UInt32)), r#"CAST(a AS INTEGER UNSIGNED)"#, ), ( @@ -2227,10 +2225,7 @@ mod tests { r#"((a + b) > 100.123)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Decimal128(10, -2), - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Decimal128(10, -2))), r#"CAST(a AS DECIMAL(12,0))"#, ), ( @@ -2367,10 +2362,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Date64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2392,10 +2384,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2625,23 +2614,23 @@ mod tests { fn 
test_cast_value_to_binary_expr() { let tests = [ ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::Binary, - }), + DataType::Binary, + )), "'blah'", ), ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::BinaryView, - }), + DataType::BinaryView, + )), "'blah'", ), ]; @@ -2672,10 +2661,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2758,10 +2744,7 @@ mod tests { [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Int64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2786,10 +2769,7 @@ mod tests { [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Int32, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2825,10 +2805,7 @@ mod tests { (&mysql_dialect, ×tamp_with_tz, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: data_type.clone(), - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type.clone())); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2881,10 +2858,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2924,13 +2898,13 @@ mod tests { #[test] fn test_cast_value_to_dict_expr() { let tests = [( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("variation".to_string())), None, )), - data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), - }), + DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), + )), "'variation'", )]; for (value, expected) in tests { @@ -2962,10 +2936,7 @@ mod tests { datafusion_functions::math::round::RoundFunc::new(), )), args: vec![ - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); @@ -3127,10 +3098,12 @@ mod tests { let unparser = Unparser::new(&dialect); - let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ))?; + let arrow_field = Arc::new(Field::new( + "", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )); + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?; assert_eq!(ast_dtype, ast::DataType::Varchar(None)); @@ -3144,7 +3117,8 @@ mod tests { .build(); let unparser = 
Unparser::new(&dialect); - let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?; + let arrow_field = Arc::new(Field::new("", DataType::Utf8View, true)); + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?; assert_eq!(ast_dtype, ast::DataType::Char(None)); @@ -3212,10 +3186,10 @@ mod tests { let dialect: Arc = Arc::new(SqliteDialect {}); let unparser = Unparser::new(dialect.as_ref()); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Nanosecond, None), - }); + let expr = Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Nanosecond, None), + )); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 71c2140bac8bf..2b9838406f200 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::logical_plan::producer::{to_substrait_type_from_field, SubstraitProducer}; use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; use datafusion::common::{DFSchemaRef, ScalarValue}; use datafusion::logical_expr::{Cast, Expr, TryCast}; @@ -39,8 +39,8 @@ pub fn from_cast( let lit = Literal { nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::Null(to_substrait_type( - producer, data_type, true, + literal_type: Some(LiteralType::Null(to_substrait_type_from_field( + producer, data_type, )?)), }; return Ok(Expression { @@ -51,7 +51,7 @@ pub fn from_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(producer, data_type, true)?), + r#type: Some(to_substrait_type_from_field(producer, data_type)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, @@ -68,7 +68,7 @@ pub fn from_try_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(producer, data_type, true)?), + r#type: Some(to_substrait_type_from_field(producer, data_type)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, @@ -80,7 +80,7 @@ pub fn from_try_cast( mod tests { use super::*; use crate::logical_plan::producer::{ - to_substrait_extended_expr, DefaultSubstraitProducer, + to_substrait_extended_expr, to_substrait_type, DefaultSubstraitProducer, }; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DFSchema; diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 0613ed07be2a5..9019170b0225a 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -27,7 +27,7 @@ use crate::variation_const::{ TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; -use datafusion::arrow::datatypes::{DataType, IntervalUnit}; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; use datafusion::common::{not_impl_err, plan_err, DFSchemaRef}; use 
substrait::proto::{r#type, NamedStruct}; @@ -36,12 +36,19 @@ pub(crate) fn to_substrait_type( dt: &DataType, nullable: bool, ) -> datafusion::common::Result { - let nullability = if nullable { + to_substrait_type_from_field(producer, &Field::new("", dt.clone(), nullable).into()) +} + +pub(crate) fn to_substrait_type_from_field( + producer: &mut impl SubstraitProducer, + dt: &FieldRef, +) -> datafusion::common::Result { + let nullability = if dt.is_nullable() { r#type::Nullability::Nullable as i32 } else { r#type::Nullability::Required as i32 }; - match dt { + match dt.data_type() { DataType::Null => { let type_anchor = producer.register_type(NULL_TYPE_NAME.to_string()); Ok(substrait::proto::Type { @@ -310,8 +317,8 @@ pub(crate) fn to_substrait_type( _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), }, DataType::Dictionary(key_type, value_type) => { - let key_type = to_substrait_type(producer, key_type, nullable)?; - let value_type = to_substrait_type(producer, value_type, nullable)?; + let key_type = to_substrait_type(producer, key_type, dt.is_nullable())?; + let value_type = to_substrait_type(producer, value_type, dt.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), From 89547f4d1f861d0b15e6e013d9b9331d2d27002e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Oct 2025 12:25:50 -0500 Subject: [PATCH 02/18] clippy --- datafusion/substrait/src/logical_plan/producer/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 9019170b0225a..a3c0bc17760a2 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -376,7 +376,7 @@ pub(crate) fn to_substrait_named_struct( types: schema .fields() .iter() - .map(|f| to_substrait_type(producer, f.data_type(), f.is_nullable())) + .map(|f| to_substrait_type_from_field(producer, f)) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, From 1e5e1b0c66cec03ab7e2d3f549933a6a9916c2aa Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Oct 2025 12:40:48 -0500 Subject: [PATCH 03/18] proto --- datafusion/proto/proto/datafusion.proto | 4 ++ datafusion/proto/src/generated/pbjson.rs | 72 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 14 ++++ .../proto/src/logical_plan/from_proto.rs | 9 ++- datafusion/proto/src/logical_plan/to_proto.rs | 4 ++ .../src/logical_plan/producer/types.rs | 4 +- 6 files changed, 102 insertions(+), 5 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 789176862bf00..7bf99473910af 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -593,11 +593,15 @@ message WhenThen { message CastNode { LogicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + map metadata = 3; + optional bool nullable = 4; } message TryCastNode { LogicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + map metadata = 3; + optional bool nullable = 4; } message SortExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 230bfa495a4b3..b2a54907eb046 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1834,6 +1834,12 
@@ impl serde::Serialize for CastNode { if self.arrow_type.is_some() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.nullable.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -1841,6 +1847,12 @@ impl serde::Serialize for CastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if let Some(v) = self.nullable.as_ref() { + struct_ser.serialize_field("nullable", v)?; + } struct_ser.end() } } @@ -1854,12 +1866,16 @@ impl<'de> serde::Deserialize<'de> for CastNode { "expr", "arrow_type", "arrowType", + "metadata", + "nullable", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + Metadata, + Nullable, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1883,6 +1899,8 @@ impl<'de> serde::Deserialize<'de> for CastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "metadata" => Ok(GeneratedField::Metadata), + "nullable" => Ok(GeneratedField::Nullable), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1904,6 +1922,8 @@ impl<'de> serde::Deserialize<'de> for CastNode { { let mut expr__ = None; let mut arrow_type__ = None; + let mut metadata__ = None; + let mut nullable__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -1918,11 +1938,27 @@ impl<'de> serde::Deserialize<'de> for CastNode { } arrow_type__ = map_.next_value()?; } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? 
+ ); + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = map_.next_value()?; + } } } Ok(CastNode { expr: expr__, arrow_type: arrow_type__, + metadata: metadata__.unwrap_or_default(), + nullable: nullable__, }) } } @@ -22266,6 +22302,12 @@ impl serde::Serialize for TryCastNode { if self.arrow_type.is_some() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.nullable.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.TryCastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -22273,6 +22315,12 @@ impl serde::Serialize for TryCastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if let Some(v) = self.nullable.as_ref() { + struct_ser.serialize_field("nullable", v)?; + } struct_ser.end() } } @@ -22286,12 +22334,16 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { "expr", "arrow_type", "arrowType", + "metadata", + "nullable", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + Metadata, + Nullable, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22315,6 +22367,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "metadata" => Ok(GeneratedField::Metadata), + "nullable" => Ok(GeneratedField::Nullable), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22336,6 +22390,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { let mut expr__ = None; let mut arrow_type__ = None; + let mut metadata__ = None; + let mut nullable__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -22350,11 +22406,27 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { } arrow_type__ = map_.next_value()?; } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? 
+ ); + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = map_.next_value()?; + } } } Ok(TryCastNode { expr: expr__, arrow_type: arrow_type__, + metadata: metadata__.unwrap_or_default(), + nullable: nullable__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b2d0bc7751f9b..fe11bbf970ed3 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -919,6 +919,13 @@ pub struct CastNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub arrow_type: ::core::option::Option, + #[prost(map = "string, string", tag = "3")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(bool, optional, tag = "4")] + pub nullable: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TryCastNode { @@ -926,6 +933,13 @@ pub struct TryCastNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub arrow_type: ::core::option::Option, + #[prost(map = "string, string", tag = "3")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(bool, optional, tag = "4")] + pub nullable: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d41011845272b..71f39ce4ddaea 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -528,7 +528,8 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + let field = Field::new("", data_type, cast.nullable.unwrap_or(true)); + Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field)))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( @@ -538,7 +539,11 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + let field = Field::new("", data_type, cast.nullable.unwrap_or(true)); + Ok(Expr::TryCast(TryCast::new_from_field( + expr, + Arc::new(field), + ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 346d45c05187e..187aef25b8a8e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -526,6 +526,8 @@ pub fn serialize_expr( let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.data_type().try_into()?), + metadata: data_type.metadata().clone(), + nullable: Some(data_type.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), @@ -535,6 +537,8 @@ pub fn serialize_expr( let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.data_type().try_into()?), + metadata: data_type.metadata().clone(), + nullable: Some(data_type.is_nullable()), }); protobuf::LogicalExprNode { 
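// Note (annotation, not part of the diff): the metadata map and optional `nullable`
// flag added to `CastNode` and `TryCastNode` above let the serialized form carry the
// extra information from the cast's target `FieldRef`; on the read side
// (`from_proto.rs` in this patch) the expression is rebuilt with
// `Cast::new_from_field` / `TryCast::new_from_field`, using the serialized nullability.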
expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index a3c0bc17760a2..5631fed56d8c2 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -331,9 +331,7 @@ pub(crate) fn to_substrait_type_from_field( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| { - to_substrait_type(producer, field.data_type(), field.is_nullable()) - }) + .map(|f| to_substrait_type_from_field(producer, f)) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { From b23bb4252962840419b850055266b7f879528f06 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 11:45:53 -0500 Subject: [PATCH 04/18] fmt --- .../src/logical_plan/producer/types.rs | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 5631fed56d8c2..c691baf8f506a 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -295,16 +295,9 @@ pub(crate) fn to_substrait_type_from_field( } DataType::Map(inner, _) => match inner.data_type() { DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - let key_type = to_substrait_type( - producer, - key_and_value[0].data_type(), - key_and_value[0].is_nullable(), - )?; - let value_type = to_substrait_type( - producer, - key_and_value[1].data_type(), - key_and_value[1].is_nullable(), - )?; + let key_type = to_substrait_type_from_field(producer, &key_and_value[0])?; + let value_type = + to_substrait_type_from_field(producer, &key_and_value[1])?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -317,8 +310,14 @@ pub(crate) fn to_substrait_type_from_field( _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), }, DataType::Dictionary(key_type, value_type) => { - let key_type = to_substrait_type(producer, key_type, dt.is_nullable())?; - let value_type = to_substrait_type(producer, value_type, dt.is_nullable())?; + let key_type = to_substrait_type_from_field( + producer, + &Field::new("", key_type.as_ref().clone(), dt.is_nullable()).into(), + )?; + let value_type = to_substrait_type_from_field( + producer, + &Field::new("", value_type.as_ref().clone(), dt.is_nullable()).into(), + )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -331,7 +330,7 @@ pub(crate) fn to_substrait_type_from_field( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|f| to_substrait_type_from_field(producer, f)) + .map(|field| to_substrait_type_from_field(producer, field)) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { From 4ac02964ed4500eb45b3bdc13cd05361ee540036 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 12:12:42 -0500 Subject: [PATCH 05/18] use the helper --- datafusion/expr/src/expr.rs | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 17e9de5e25e45..7c3e900aa4fb4 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,6 +32,7 @@ use crate::{ExprSchemable, 
Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; @@ -3304,28 +3305,18 @@ impl Display for Expr { write!(f, "END") } Expr::Cast(Cast { expr, data_type }) => { - if data_type.metadata().is_empty() { - write!(f, "CAST({expr} AS {})", data_type.data_type()) - } else { - write!( - f, - "CAST({expr} AS {}<{:?}>)", - data_type.data_type(), - data_type.metadata() - ) - } + let formatted = format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()), + ); + write!(f, "CAST({expr} AS {})", formatted) } Expr::TryCast(TryCast { expr, data_type }) => { - if data_type.metadata().is_empty() { - write!(f, "TRY_CAST({expr} AS {})", data_type.data_type()) - } else { - write!( - f, - "TRY_CAST({expr} AS {}<{:?}>)", - data_type.data_type(), - data_type.metadata() - ) - } + let formatted = format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()), + ); + write!(f, "TRY_CAST({expr} AS {})", formatted) } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), From 32c6d2669c823a6caf2e6181dff7836f9982d604 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 20:37:40 -0500 Subject: [PATCH 06/18] clippy --- datafusion/expr/src/expr.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7c3e900aa4fb4..0f236b7517d21 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3309,14 +3309,14 @@ impl Display for Expr { data_type.data_type(), Some(data_type.metadata()), ); - write!(f, "CAST({expr} AS {})", formatted) + write!(f, "CAST({expr} AS {formatted})") } Expr::TryCast(TryCast { expr, data_type }) => { let formatted = format_type_and_metadata( data_type.data_type(), Some(data_type.metadata()), ); - write!(f, "TRY_CAST({expr} AS {})", formatted) + write!(f, "TRY_CAST({expr} AS {formatted})") } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), From e2e695eb9ed97913eed6fcdc3521b7bcd5cc27ca Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 21:56:54 -0500 Subject: [PATCH 07/18] test not suported --- datafusion/physical-expr/src/planner.rs | 95 +++++++++++++++++++++---- 1 file changed, 81 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 61b85724ff454..ad22f90e65ac5 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -25,7 +25,7 @@ use crate::{ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; -use datafusion_common::metadata::FieldMetadata; +use datafusion_common::metadata::{format_type_and_metadata, FieldMetadata}; use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; @@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, + binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, }; /// [PhysicalExpr] 
evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -287,18 +287,50 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => expressions::cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - // TODO: this drops extension metadata associated with the cast - data_type.data_type().clone(), - ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - // TODO: this drops extension metadata associated with the cast - data_type.data_type().clone(), - ), + Expr::Cast(Cast { expr, data_type }) => { + if !data_type.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "Cast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()) + ) + ); + } + + expressions::cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + data_type.data_type().clone(), + ) + } + Expr::TryCast(TryCast { expr, data_type }) => { + if !data_type.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "TryCast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()) + ) + ); + } + + expressions::try_cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + data_type.data_type().clone(), + ) + } Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) 
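// Note (annotation, not part of the diff) on the Cast/TryCast arms rewritten just
// below: when the cast's target field carries metadata (for example an
// `ARROW:extension:name` entry), physical planning now returns a plan error built
// with `format_type_and_metadata` instead of silently discarding the metadata, which
// is what the earlier TODO in this arm described. The `test_cast_to_extension_type`
// test added to this file's tests in the same patch exercises exactly that path with
// the `arrow.uuid` extension type.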
} @@ -419,6 +451,7 @@ mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field}; + use datafusion_common::datatype::DataTypeExt; use datafusion_expr::{col, lit}; use super::*; @@ -447,4 +480,38 @@ mod tests { Ok(()) } + + #[test] + fn test_cast_to_extension_type() -> Result<()> { + let extension_field_type = Arc::new( + DataType::FixedSizeBinary(16) + .into_nullable_field() + .with_metadata( + [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())] + .into(), + ), + ); + let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58"); + let cast_expr = Expr::Cast(Cast::new_from_field( + Box::new(expr.clone()), + extension_field_type.clone(), + )); + let err = + create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new()) + .unwrap_err(); + assert!(err.message().contains("arrow.uuid")); + + let try_cast_expr = Expr::TryCast(TryCast::new_from_field( + Box::new(expr.clone()), + extension_field_type.clone(), + )); + let err = create_physical_expr( + &try_cast_expr, + &DFSchema::empty(), + &ExecutionProps::new(), + ) + .unwrap_err(); + assert!(err.message().contains("arrow.uuid")); + Ok(()) + } } From 64980abe59d40f9a3389cb959f063b7c257bd1bc Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 22:15:21 -0500 Subject: [PATCH 08/18] rename data type to field --- .../provider_filter_pushdown.rs | 2 +- datafusion/expr/src/expr.rs | 59 +++++++------------ datafusion/expr/src/expr_rewriter/order_by.rs | 8 +-- datafusion/expr/src/expr_schema.rs | 9 ++- datafusion/expr/src/tree_node.rs | 8 +-- datafusion/functions/src/core/arrow_cast.rs | 4 +- .../optimizer/src/eliminate_outer_join.rs | 4 +- datafusion/physical-expr/src/planner.rs | 22 +++---- datafusion/proto/src/logical_plan/to_proto.rs | 16 ++--- datafusion/sql/src/unparser/expr.rs | 8 +-- .../src/logical_plan/producer/expr/cast.rs | 10 ++-- 11 files changed, 64 insertions(+), 86 deletions(-) diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index c80c0b4bf54ba..ca01a0657988c 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -183,7 +183,7 @@ impl TableProvider for CustomProvider { Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { + Expr::Cast(Cast { expr, field: _ }) => match expr.deref() { Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0f236b7517d21..b54ba4217302e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,6 +32,7 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -797,7 +798,7 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - 
pub data_type: FieldRef, + pub field: FieldRef, } impl Cast { @@ -805,15 +806,12 @@ impl Cast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, - data_type: Field::new("", data_type, true).into(), + field: data_type.into_nullable_field_ref(), } } pub fn new_from_field(expr: Box, field: FieldRef) -> Self { - Self { - expr, - data_type: field, - } + Self { expr, field } } } @@ -823,7 +821,7 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: FieldRef, + pub field: FieldRef, } impl TryCast { @@ -831,15 +829,12 @@ impl TryCast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, - data_type: Field::new("", data_type, true).into(), + field: data_type.into_nullable_field_ref(), } } pub fn new_from_field(expr: Box, field: FieldRef) -> Self { - Self { - expr, - data_type: field, - } + Self { expr, field } } } @@ -2273,23 +2268,23 @@ impl NormalizeEq for Expr { ( Expr::Cast(Cast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::Cast(Cast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), ) | ( Expr::TryCast(TryCast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::TryCast(TryCast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), - ) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ) => self_field == other_field && self_expr.normalize_eq(other_expr), ( Expr::ScalarFunction(ScalarFunction { func: self_func, @@ -2605,15 +2600,9 @@ impl HashNode for Expr { when_then_expr: _when_then_expr, else_expr: _else_expr, }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { - data_type.hash(state); + Expr::Cast(Cast { expr: _expr, field }) + | Expr::TryCast(TryCast { expr: _expr, field }) => { + field.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(state); @@ -3304,18 +3293,14 @@ impl Display for Expr { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { - let formatted = format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()), - ); + Expr::Cast(Cast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); write!(f, "CAST({expr} AS {formatted})") } - Expr::TryCast(TryCast { expr, data_type }) => { - let formatted = format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()), - ); + Expr::TryCast(TryCast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); write!(f, "TRY_CAST({expr} AS {formatted})") } Expr::Not(expr) => write!(f, "NOT {expr}"), @@ -3702,7 +3687,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), - data_type: Field::new("", DataType::Utf8, true).into(), + field: DataType::Utf8.into_nullable_field_ref(), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c21c6e6222a05..3cb02ec2e39b9 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection( if let Some(found) = found { return Ok(Transformed::yes(match 
normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { + Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast { expr: Box::new(found), - data_type, + field, }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast { expr: Box::new(found), - data_type, + field, }), _ => found, })); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3f07d2e0d71e9..86fcb2dc3a2db 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -129,9 +129,8 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => { - Ok(data_type.data_type().clone()) + Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, .. }) => { + Ok(field.data_type().clone()) } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; @@ -633,12 +632,12 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .to_field(schema) .map(|(_, f)| { f.as_ref() .clone() - .with_data_type(data_type.data_type().clone()) + .with_data_type(field.data_type().clone()) .with_metadata(f.metadata().clone()) // TODO: should nullability be overridden here or derived from the // input expression? diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index e949bd71a71f5..2fbd57a11ecbb 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -220,12 +220,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new_from_field(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))), + Expr::TryCast(TryCast { expr, field }) => expr .map_elements(f)? 
- .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 09d9f483e4db7..34b25c559d01d 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -19,6 +19,7 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; @@ -164,6 +165,7 @@ impl ScalarUDFImpl for ArrowCastFunc { let arg = args.pop().unwrap(); let source_type = info.get_data_type(&arg)?; + // TODO: check type equality for real let new_expr = if source_type == target_type { // the argument's data type is already the correct type arg @@ -171,7 +173,7 @@ impl ScalarUDFImpl for ArrowCastFunc { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { expr: Box::new(arg), - data_type: Field::new("", target_type, true).into(), + field: target_type.into_nullable_field_ref(), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 45877642f2766..160c09cde2f5a 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -289,8 +289,8 @@ fn extract_non_nullable_columns( false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( + Expr::Cast(Cast { expr, field: _ }) + | Expr::TryCast(TryCast { expr, field: _ }) => extract_non_nullable_columns( expr, non_nullable_cols, left_schema, diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index ad22f90e65ac5..f3de5b43118bf 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -287,8 +287,8 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) 
} - Expr::Cast(Cast { expr, data_type }) => { - if !data_type.metadata().is_empty() { + Expr::Cast(Cast { expr, field }) => { + if !field.metadata().is_empty() { let (_, src_field) = expr.to_field(input_dfschema)?; return plan_err!( "Cast from {} to {} is not supported", @@ -296,21 +296,18 @@ pub fn create_physical_expr( src_field.data_type(), Some(src_field.metadata()), ), - format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()) - ) + format_type_and_metadata(field.data_type(), Some(field.metadata())) ); } expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.data_type().clone(), + field.data_type().clone(), ) } - Expr::TryCast(TryCast { expr, data_type }) => { - if !data_type.metadata().is_empty() { + Expr::TryCast(TryCast { expr, field }) => { + if !field.metadata().is_empty() { let (_, src_field) = expr.to_field(input_dfschema)?; return plan_err!( "TryCast from {} to {} is not supported", @@ -318,17 +315,14 @@ pub fn create_physical_expr( src_field.data_type(), Some(src_field.metadata()), ), - format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()) - ) + format_type_and_metadata(field.data_type(), Some(field.metadata())) ); } expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.data_type().clone(), + field.data_type().clone(), ) } Expr::Not(expr) => { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 187aef25b8a8e..c1b9f30f2de88 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -522,23 +522,23 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, field }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.data_type().try_into()?), - metadata: data_type.metadata().clone(), - nullable: Some(data_type.is_nullable()), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.data_type().try_into()?), - metadata: data_type.metadata().clone(), - nullable: Some(data_type.is_nullable()), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index f7d0134ff696c..b44f503b3bd36 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -189,9 +189,7 @@ impl Unparser<'_> { end_token: AttachedToken::empty(), }) } - Expr::Cast(Cast { expr, data_type }) => { - Ok(self.cast_to_sql(expr, data_type)?) - } + Expr::Cast(Cast { expr, field }) => Ok(self.cast_to_sql(expr, field)?), Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. 
}) => self.expr_to_sql_inner(expr), Expr::WindowFunction(window_fun) => { @@ -462,12 +460,12 @@ impl Unparser<'_> { ) }) } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 2b9838406f200..889a285eb6a0c 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -29,7 +29,7 @@ pub fn from_cast( cast: &Cast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let Cast { expr, data_type } = cast; + let Cast { expr, field } = cast; // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null if let Expr::Literal(lit, _) = expr.as_ref() { // only the untyped(a null scalar value) null literal need this special handling @@ -40,7 +40,7 @@ pub fn from_cast( nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, literal_type: Some(LiteralType::Null(to_substrait_type_from_field( - producer, data_type, + producer, field, )?)), }; return Ok(Expression { @@ -51,7 +51,7 @@ pub fn from_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type_from_field(producer, data_type)?), + r#type: Some(to_substrait_type_from_field(producer, field)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, @@ -64,11 +64,11 @@ pub fn from_try_cast( cast: &TryCast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let TryCast { expr, data_type } = cast; + let TryCast { expr, field } = cast; Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type_from_field(producer, data_type)?), + r#type: Some(to_substrait_type_from_field(producer, field)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, From e6a7e909502c2c3a4000822b59c8cebb3b5127e6 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 22:48:00 -0500 Subject: [PATCH 09/18] clippy --- datafusion/physical-expr/src/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f3de5b43118bf..bc80c6df1aa23 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -488,7 +488,7 @@ mod tests { let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58"); let cast_expr = Expr::Cast(Cast::new_from_field( Box::new(expr.clone()), - extension_field_type.clone(), + Arc::clone(&extension_field_type), )); let err = create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new()) @@ -497,7 +497,7 @@ mod tests { let try_cast_expr = Expr::TryCast(TryCast::new_from_field( Box::new(expr.clone()), - extension_field_type.clone(), + Arc::clone(&extension_field_type), )); let err = create_physical_expr( &try_cast_expr, From 0955d58bf595b5b6efa5044b09d4ef8b6874d165 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 23:08:00 -0500 Subject: [PATCH 10/18] 
maybe better substrait consumer integration --- .../src/logical_plan/consumer/expr/cast.rs | 11 +++++----- .../src/logical_plan/consumer/types.rs | 20 ++++++++++++++++++- .../src/logical_plan/producer/types.rs | 12 +++++------ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs index 5e8d3d93065f4..ff88464b8bb05 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::consumer::types::from_substrait_type_without_names; -use crate::logical_plan::consumer::SubstraitConsumer; +use crate::logical_plan::consumer::{ + field_from_substrait_type_without_names, SubstraitConsumer, +}; use datafusion::common::{substrait_err, DFSchema}; use datafusion::logical_expr::{Cast, Expr, TryCast}; use substrait::proto::expression as substrait_expression; @@ -37,11 +38,11 @@ pub async fn from_cast( ) .await?, ); - let data_type = from_substrait_type_without_names(consumer, output_type)?; + let field = field_from_substrait_type_without_names(consumer, output_type)?; if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + Ok(Expr::TryCast(TryCast::new_from_field(input_expr, field))) } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) + Ok(Expr::Cast(Cast::new_from_field(input_expr, field))) } } None => substrait_err!("Cast expression without output type is not allowed"), diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index ef1000a1ccdba..91e2e43a41c36 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -34,14 +34,22 @@ use crate::variation_const::{ }; use crate::variation_const::{FLOAT_16_TYPE_NAME, NULL_TYPE_NAME}; use datafusion::arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; +use datafusion::common::datatype::DataTypeExt; use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, }; use std::sync::Arc; use substrait::proto::{r#type, NamedStruct, Type}; +pub(crate) fn field_from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, + dt: &Type, +) -> datafusion::common::Result { + Ok(from_substrait_type_without_names(consumer, dt)?.into_nullable_field_ref()) +} + pub(crate) fn from_substrait_type_without_names( consumer: &impl SubstraitConsumer, dt: &Type, @@ -49,6 +57,16 @@ pub(crate) fn from_substrait_type_without_names( from_substrait_type(consumer, dt, &[], &mut 0) } +pub fn field_from_substrait_type( + consumer: &impl SubstraitConsumer, + dt: &Type, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + // We could add nullability here now that we are returning a Field + Ok(from_substrait_type(consumer, dt, dfs_names, name_idx)?.into_nullable_field_ref()) +} + pub fn from_substrait_type( consumer: &impl SubstraitConsumer, dt: &Type, diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index c691baf8f506a..aaeebafca12f5 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ 
b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -41,14 +41,14 @@ pub(crate) fn to_substrait_type( pub(crate) fn to_substrait_type_from_field( producer: &mut impl SubstraitProducer, - dt: &FieldRef, + field: &FieldRef, ) -> datafusion::common::Result { - let nullability = if dt.is_nullable() { + let nullability = if field.is_nullable() { r#type::Nullability::Nullable as i32 } else { r#type::Nullability::Required as i32 }; - match dt.data_type() { + match field.data_type() { DataType::Null => { let type_anchor = producer.register_type(NULL_TYPE_NAME.to_string()); Ok(substrait::proto::Type { @@ -312,11 +312,11 @@ pub(crate) fn to_substrait_type_from_field( DataType::Dictionary(key_type, value_type) => { let key_type = to_substrait_type_from_field( producer, - &Field::new("", key_type.as_ref().clone(), dt.is_nullable()).into(), + &Field::new("", key_type.as_ref().clone(), field.is_nullable()).into(), )?; let value_type = to_substrait_type_from_field( producer, - &Field::new("", value_type.as_ref().clone(), dt.is_nullable()).into(), + &Field::new("", value_type.as_ref().clone(), field.is_nullable()).into(), )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { @@ -356,7 +356,7 @@ pub(crate) fn to_substrait_type_from_field( precision: *p as i32, })), }), - _ => not_impl_err!("Unsupported cast type: {dt}"), + _ => not_impl_err!("Unsupported cast type: {field}"), } } From 2e103a59756ce7520f71bf63616138f4827cbc5c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 23:20:51 -0500 Subject: [PATCH 11/18] no need to update this --- datafusion/functions/src/core/arrow_cast.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 34b25c559d01d..6780cc4d197ad 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -165,7 +165,6 @@ impl ScalarUDFImpl for ArrowCastFunc { let arg = args.pop().unwrap(); let source_type = info.get_data_type(&arg)?; - // TODO: check type equality for real let new_expr = if source_type == target_type { // the argument's data type is already the correct type arg From 00237be47839d7d66683b3755445ba95acb8cfc1 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 23:26:23 -0500 Subject: [PATCH 12/18] comment about nullability --- datafusion/expr/src/expr_schema.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 86fcb2dc3a2db..b1dc95d9d2f17 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -635,12 +635,13 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { expr, field }) => expr .to_field(schema) .map(|(_, f)| { + // This currently propagates the nullability of the input + // expression as the resulting physical expression does + // not currently consider the nullability specified here f.as_ref() .clone() .with_data_type(field.data_type().clone()) .with_metadata(f.metadata().clone()) - // TODO: should nullability be overridden here or derived from the - // input expression? 
}) .map(Arc::new), Expr::Placeholder(Placeholder { From b402c442f5ded919b7f0c7bea326506dd9ffc366 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 1 Dec 2025 15:14:13 -0600 Subject: [PATCH 13/18] nope use the other version --- datafusion/expr/src/expr_schema.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 75cc2695b9c46..c7adc1bdc2f67 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -643,7 +643,16 @@ impl ExprSchemable for Expr { // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), Expr::Cast(Cast { expr, field }) => expr .to_field(schema) - .map(|(_, f)| f.retyped(data_type.clone())), + .map(|(_, f)| { + // This currently propagates the nullability of the input + // expression as the resulting physical expression does + // not currently consider the nullability specified here + f.as_ref() + .clone() + .with_data_type(field.data_type().clone()) + .with_metadata(f.metadata().clone()) + }) + .map(Arc::new), Expr::Placeholder(Placeholder { id: _, field: Some(field), From 5657e89cd42f0697f21521938454440f7a0b2e08 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 8 Dec 2025 14:32:17 -0600 Subject: [PATCH 14/18] add upgrade guide --- docs/source/library-user-guide/upgrading.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index d40fee093816a..487817081b64b 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -332,6 +332,21 @@ let config = FileScanConfigBuilder::new(url, source) **Handling projections in `FileSource`:** +### `Expr::ScalarVariable` and `Expr::Cast` now store type information as a `FieldRef` + +Code that explicitly constructed `Expr::ScalarVariable(DataType::xxx)` must now convert +the `DataType` into a `FieldRef` (e.g., `Expr::ScalarVariable(DataType::Int8.into_nullable_field_ref())`). +Implementations of a custom `ContextProvider` may implement `get_variable_field()` to provide +planning information for variables with Arrow extension types or explicit nullability. See +[#18243](https://github.com/apache/datafusion/pull/18243) for more information. + +Similarly, the `Cast` type wrapped by `Expr::Cast` was updated to use a `FieldRef` as the +underlying storage to represent casts to Arrow extension types. Code that constructed +or matched casts in the form `Expr::Cast(Cast { expr, data_type })` will need to be updated. +Casts to `DataType` can continue to be constructed with `Expr::Cast(Cast::new(expr, DataType::...))`, +although we reccomend `Expr::Cast(Cast::new_from_field(expr, existing_field_ref))` when +constructing casts programmatically to avoid dropping extension type metdata. 
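+
+For example, here is a minimal sketch of both constructors (the column name,
+extension type metadata, and exact import paths below are illustrative only):
+
+```rust
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, Field};
+use datafusion_expr::{col, Cast, Expr};
+
+// Casting to a plain `DataType` works as before; the `FieldRef` is created internally.
+let plain_cast = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int64));
+
+// Casting to a `FieldRef` keeps extension type metadata attached to the field.
+let uuid_field = Arc::new(
+    Field::new("", DataType::FixedSizeBinary(16), true).with_metadata(HashMap::from([(
+        "ARROW:extension:name".to_string(),
+        "arrow.uuid".to_string(),
+    )])),
+);
+let extension_cast = Expr::Cast(Cast::new_from_field(Box::new(col("a")), uuid_field));
+```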
+ ## DataFusion `51.0.0` ### `arrow` / `parquet` updated to 57.0.0 From f4d267372fe2cd823beb138015c70d846340eea2 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 8 Dec 2025 14:34:14 -0600 Subject: [PATCH 15/18] fmt --- datafusion/physical-expr/src/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index b3836b6dfcf13..f7a0f21563836 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -25,7 +25,7 @@ use crate::{ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; -use datafusion_common::metadata::{format_type_and_metadata, FieldMetadata}; +use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err, }; @@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::VarType; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, + Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, binary_expr, lit, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 From 55de0cb8009a43595157832aa54b76bd026aae15 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 8 Dec 2025 16:19:20 -0600 Subject: [PATCH 16/18] typos --- docs/source/library-user-guide/upgrading.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 034e0a0ec0701..41ccdb6dd4aba 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -392,8 +392,8 @@ Similarly, the `Cast` type wrapped by `Expr::Cast` was updated to use a `FieldRe underlying storage to represent casts to Arrow extension types. Code that constructed or matched casts in the form `Expr::Cast(Cast { expr, data_type })` will need to be updated. Casts to `DataType` can continue to be constructed with `Expr::Cast(Cast::new(expr, DataType::...))`, -although we reccomend `Expr::Cast(Cast::new_from_field(expr, existing_field_ref))` when -constructing casts programmatically to avoid dropping extension type metdata. +although we recommend `Expr::Cast(Cast::new_from_field(expr, existing_field_ref))` when +constructing casts programmatically to avoid dropping extension type metadata. 
## DataFusion `51.0.0` From 7048c7e137e9a9c22f3032b73cc2a2b827818b9c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 16 Dec 2025 11:55:48 -0600 Subject: [PATCH 17/18] comments --- datafusion/expr/src/expr_schema.rs | 17 +++++++++++------ datafusion/proto/src/logical_plan/from_proto.rs | 15 ++++++++++----- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1a756fd0e4cf4..43ce7070602bb 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -627,14 +627,19 @@ impl ExprSchemable for Expr { // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), Expr::Cast(Cast { expr, field }) => expr .to_field(schema) - .map(|(_, f)| { - // This currently propagates the nullability of the input - // expression as the resulting physical expression does - // not currently consider the nullability specified here - f.as_ref() + .map(|(_table_ref, destination_field)| { + // This propagates the nullability of the input rather than + // force the nullability of the destination field. This is + // usually the desired behaviour (i.e., specifying a cast + // destination type usually does not force a user to pick + // nullability, and assuming `true` would prevent the non-nullability + // of the parent expression to make the result eligible for + // optimizations that only apply to non-nullable values). + destination_field + .as_ref() .clone() .with_data_type(field.data_type().clone()) - .with_metadata(f.metadata().clone()) + .with_metadata(destination_field.metadata().clone()) }) .map(Arc::new), Expr::Placeholder(Placeholder { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 71f39ce4ddaea..7acad899fe9d0 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,7 +17,8 @@ use std::sync::Arc; -use arrow::datatypes::Field; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, @@ -527,8 +528,10 @@ pub fn parse_expr( "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - let field = Field::new("", data_type, cast.nullable.unwrap_or(true)); + let data_type: DataType = cast.arrow_type.as_ref().required("arrow_type")?; + let field = data_type + .into_nullable_field() + .with_nullable(cast.nullable.unwrap_or(true)); Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field)))) } ExprType::TryCast(cast) => { @@ -538,8 +541,10 @@ pub fn parse_expr( "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - let field = Field::new("", data_type, cast.nullable.unwrap_or(true)); + let data_type: DataType = cast.arrow_type.as_ref().required("arrow_type")?; + let field = data_type + .into_nullable_field() + .with_nullable(cast.nullable.unwrap_or(true)); Ok(Expr::TryCast(TryCast::new_from_field( expr, Arc::new(field), From 87ea98aa13425822d9cbe9a52dca6d47201cfd5e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 5 Jan 2026 14:25:29 -0600 Subject: [PATCH 18/18] fix build and format --- datafusion/functions/src/core/arrow_cast.rs | 6 +++--- datafusion/physical-expr/src/planner.rs | 9 +++++---- .../substrait/src/logical_plan/consumer/expr/cast.rs | 5 +++-- 
.../substrait/src/logical_plan/producer/expr/cast.rs | 3 +-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 4896d7862111c..5fa63baac450a 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -20,9 +20,9 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::{ - Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err, - exec_datafusion_err, utils::take_function_args, types::logical_string, - datatype::DataTypeExt, + Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt, + exec_datafusion_err, exec_err, internal_err, types::logical_string, + utils::take_function_args, }; use std::any::Any; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7cdc425badc0f..5c170700d9833 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -445,8 +445,9 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field}; - use datafusion_expr::{Operator, col, lit, datatype::DataTypeExt}; - + use datafusion_common::datatype::DataTypeExt; + use datafusion_expr::{Operator, col, lit}; + use super::*; #[test] @@ -505,10 +506,10 @@ mod tests { ) .unwrap_err(); assert!(err.message().contains("arrow.uuid")); - + Ok(()) } - + /// Test that deeply nested expressions do not cause a stack overflow. /// /// This test only runs when the `recursive_protection` feature is enabled, diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs index dc789003d8c04..3dd62afe8f193 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::consumer::SubstraitConsumer; -use crate::logical_plan::consumer::types::from_substrait_type_without_names; +use crate::logical_plan::consumer::{ + SubstraitConsumer, field_from_substrait_type_without_names, +}; use datafusion::common::{DFSchema, substrait_err}; use datafusion::logical_expr::{Cast, Expr, TryCast}; use substrait::proto::expression as substrait_expression; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 8922a8ac9d690..e2140a9a2254f 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use crate::logical_plan::producer::{to_substrait_type_from_field, SubstraitProducer}; use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type_from_field}; use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; use datafusion::common::{DFSchemaRef, ScalarValue}; @@ -81,7 +80,7 @@ pub fn from_try_cast( mod tests { use super::*; use crate::logical_plan::producer::{ - DefaultSubstraitProducer, to_substrait_type, to_substrait_extended_expr, + DefaultSubstraitProducer, to_substrait_extended_expr, to_substrait_type, }; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DFSchema;