diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index ca1eaa1f958ea..ace3ab61a547c 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -193,7 +193,7 @@ impl TableProvider for CustomProvider { Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { + Expr::Cast(Cast { expr, field: _ }) => match expr.deref() { Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 24cade1e80d5a..5bf2a60165cec 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -715,10 +715,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: DataType::Int64, - }) + Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), DataType::Int64)) }; // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c7d825ce1d52f..cb3e4b038e145 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,6 +32,8 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use 
datafusion_common::datatype::DataTypeExt; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; @@ -796,13 +798,20 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub field: FieldRef, } impl Cast { /// Create a new Cast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + field: data_type.into_nullable_field_ref(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { expr, field } } } @@ -812,13 +821,20 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub field: FieldRef, } impl TryCast { /// Create a new TryCast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + field: data_type.into_nullable_field_ref(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { expr, field } } } @@ -2252,23 +2268,23 @@ impl NormalizeEq for Expr { ( Expr::Cast(Cast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::Cast(Cast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), ) | ( Expr::TryCast(TryCast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::TryCast(TryCast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), - ) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ) => self_field == other_field && self_expr.normalize_eq(other_expr), ( Expr::ScalarFunction(ScalarFunction { func: self_func, @@ -2584,15 +2600,9 @@ impl HashNode for Expr { when_then_expr: _when_then_expr, else_expr: _else_expr, }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | 
Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { - data_type.hash(state); + Expr::Cast(Cast { expr: _expr, field }) + | Expr::TryCast(TryCast { expr: _expr, field }) => { + field.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(state); @@ -3282,11 +3292,15 @@ impl Display for Expr { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr} AS {data_type})") + Expr::Cast(Cast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); + write!(f, "CAST({expr} AS {formatted})") } - Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type})") + Expr::TryCast(TryCast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); + write!(f, "TRY_CAST({expr} AS {formatted})") } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), @@ -3672,7 +3686,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), - data_type: DataType::Utf8, + field: DataType::Utf8.into_nullable_field_ref(), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index ec22be525464b..7c6af56c8d961 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection( if let Some(found) = found { return Ok(Transformed::yes(match normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { + Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast { expr: Box::new(found), - data_type, + field, }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + 
Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast { expr: Box::new(found), - data_type, + field, }), _ => found, })); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 854e907d68b1a..6e7f9aa8452b0 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -132,8 +132,9 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), + Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, .. }) => { + Ok(field.data_type().clone()) + } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list @@ -611,9 +612,23 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .to_field(schema) - .map(|(_, f)| f.retyped(data_type.clone())), + .map(|(_table_ref, destination_field)| { + // This propagates the nullability of the input rather than + // force the nullability of the destination field. This is + // usually the desired behaviour (i.e., specifying a cast + // destination type usually does not force a user to pick + // nullability, and assuming `true` would prevent the non-nullability + // of the parent expression to make the result eligible for + // optimizations that only apply to non-nullable values). 
+ destination_field + .as_ref() + .clone() + .with_data_type(field.data_type().clone()) + .with_metadata(destination_field.metadata().clone()) + }) + .map(Arc::new), Expr::Placeholder(Placeholder { id: _, field: Some(field), diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 742bae5b2320b..6c1a5702e2179 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -220,12 +220,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))), + Expr::TryCast(TryCast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 7c24450adf183..e555081e4132c 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -19,11 +19,11 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; -use datafusion_common::types::logical_string; use datafusion_common::{ - Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err, + Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt, + exec_datafusion_err, exec_err, internal_err, types::logical_string, + utils::take_function_args, }; -use datafusion_common::{exec_datafusion_err, utils::take_function_args}; use std::any::Any; 
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; @@ -176,7 +176,7 @@ impl ScalarUDFImpl for ArrowCastFunc { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { expr: Box::new(arg), - data_type: target_type, + field: target_type.into_nullable_field_ref(), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 2c78051c14134..7fd0dcba2af0d 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -289,8 +289,8 @@ fn extract_non_nullable_columns( false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( + Expr::Cast(Cast { expr, field: _ }) + | Expr::TryCast(TryCast { expr, field: _ }) => extract_non_nullable_columns( expr, non_nullable_cols, left_schema, diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 84a6aa4309872..5c170700d9833 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -25,7 +25,7 @@ use crate::{ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; -use datafusion_common::metadata::FieldMetadata; +use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err, }; @@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::VarType; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::{ - Between, BinaryExpr, Expr, Like, Operator, TryCast, binary_expr, lit, + Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, binary_expr, lit, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 
@@ -288,16 +288,44 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => expressions::cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - data_type.clone(), - ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - data_type.clone(), - ), + Expr::Cast(Cast { expr, field }) => { + if !field.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "Cast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata(field.data_type(), Some(field.metadata())) + ); + } + + expressions::cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + field.data_type().clone(), + ) + } + Expr::TryCast(TryCast { expr, field }) => { + if !field.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "TryCast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata(field.data_type(), Some(field.metadata())) + ); + } + + expressions::try_cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + field.data_type().clone(), + ) + } Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) 
} @@ -417,7 +445,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field}; - + use datafusion_common::datatype::DataTypeExt; use datafusion_expr::{Operator, col, lit}; use super::*; @@ -447,6 +475,41 @@ mod tests { Ok(()) } + #[test] + fn test_cast_to_extension_type() -> Result<()> { + let extension_field_type = Arc::new( + DataType::FixedSizeBinary(16) + .into_nullable_field() + .with_metadata( + [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())] + .into(), + ), + ); + let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58"); + let cast_expr = Expr::Cast(Cast::new_from_field( + Box::new(expr.clone()), + Arc::clone(&extension_field_type), + )); + let err = + create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new()) + .unwrap_err(); + assert!(err.message().contains("arrow.uuid")); + + let try_cast_expr = Expr::TryCast(TryCast::new_from_field( + Box::new(expr.clone()), + Arc::clone(&extension_field_type), + )); + let err = create_physical_expr( + &try_cast_expr, + &DFSchema::empty(), + &ExecutionProps::new(), + ) + .unwrap_err(); + assert!(err.message().contains("arrow.uuid")); + + Ok(()) + } + /// Test that deeply nested expressions do not cause a stack overflow. 
/// /// This test only runs when the `recursive_protection` feature is enabled, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd7dd3a6aff3c..8defbb31a12cf 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -593,11 +593,15 @@ message WhenThen { message CastNode { LogicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + map metadata = 3; + optional bool nullable = 4; } message TryCastNode { LogicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + map metadata = 3; + optional bool nullable = 4; } message SortExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e269606d163a3..6fa12f91284e4 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1979,6 +1979,12 @@ impl serde::Serialize for CastNode { if self.arrow_type.is_some() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.nullable.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -1986,6 +1992,12 @@ impl serde::Serialize for CastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if let Some(v) = self.nullable.as_ref() { + struct_ser.serialize_field("nullable", v)?; + } struct_ser.end() } } @@ -1999,12 +2011,16 @@ impl<'de> serde::Deserialize<'de> for CastNode { "expr", "arrow_type", "arrowType", + "metadata", + "nullable", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + Metadata, + Nullable, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2028,6 +2044,8 @@ impl<'de> 
serde::Deserialize<'de> for CastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "metadata" => Ok(GeneratedField::Metadata), + "nullable" => Ok(GeneratedField::Nullable), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2049,6 +2067,8 @@ impl<'de> serde::Deserialize<'de> for CastNode { { let mut expr__ = None; let mut arrow_type__ = None; + let mut metadata__ = None; + let mut nullable__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -2063,11 +2083,27 @@ impl<'de> serde::Deserialize<'de> for CastNode { } arrow_type__ = map_.next_value()?; } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = map_.next_value()?; + } } } Ok(CastNode { expr: expr__, arrow_type: arrow_type__, + metadata: metadata__.unwrap_or_default(), + nullable: nullable__, }) } } @@ -22632,6 +22668,12 @@ impl serde::Serialize for TryCastNode { if self.arrow_type.is_some() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.nullable.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.TryCastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -22639,6 +22681,12 @@ impl serde::Serialize for TryCastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if let Some(v) = self.nullable.as_ref() { + struct_ser.serialize_field("nullable", v)?; + } struct_ser.end() } } @@ -22652,12 +22700,16 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { "expr", "arrow_type", "arrowType", + 
"metadata", + "nullable", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + Metadata, + Nullable, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22681,6 +22733,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "metadata" => Ok(GeneratedField::Metadata), + "nullable" => Ok(GeneratedField::Nullable), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22702,6 +22756,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { let mut expr__ = None; let mut arrow_type__ = None; + let mut metadata__ = None; + let mut nullable__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -22716,11 +22772,27 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { } arrow_type__ = map_.next_value()?; } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? 
+ ); + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = map_.next_value()?; + } } } Ok(TryCastNode { expr: expr__, arrow_type: arrow_type__, + metadata: metadata__.unwrap_or_default(), + nullable: nullable__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cf343e0258d0b..a99145c38a53e 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -919,6 +919,13 @@ pub struct CastNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub arrow_type: ::core::option::Option, + #[prost(map = "string, string", tag = "3")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(bool, optional, tag = "4")] + pub nullable: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TryCastNode { @@ -926,6 +933,13 @@ pub struct TryCastNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub arrow_type: ::core::option::Option, + #[prost(map = "string, string", tag = "3")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(bool, optional, tag = "4")] + pub nullable: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 179fe8bb7d7fe..2a433ff09fe03 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,7 +17,8 @@ use std::sync::Arc; -use arrow::datatypes::Field; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ NullEquality, 
RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err, @@ -526,8 +527,11 @@ pub fn parse_expr( "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + let data_type: DataType = cast.arrow_type.as_ref().required("arrow_type")?; + let field = data_type + .into_nullable_field() + .with_nullable(cast.nullable.unwrap_or(true)); + Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field)))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( @@ -536,8 +540,14 @@ pub fn parse_expr( "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + let data_type: DataType = cast.arrow_type.as_ref().required("arrow_type")?; + let field = data_type + .into_nullable_field() + .with_nullable(cast.nullable.unwrap_or(true)); + Ok(Expr::TryCast(TryCast::new_from_field( + expr, + Arc::new(field), + ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6e4e5d0b6eea4..d989caef818da 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -521,19 +521,23 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, field }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) 
=> { + Expr::TryCast(TryCast { expr, field }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 5746a568e712b..42dff62fcc210 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -37,6 +37,7 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, + Field, FieldRef, }; use arrow::util::display::array_value_to_string; use datafusion_common::{ @@ -188,9 +189,7 @@ impl Unparser<'_> { end_token: AttachedToken::empty(), }) } - Expr::Cast(Cast { expr, data_type }) => { - Ok(self.cast_to_sql(expr, data_type)?) - } + Expr::Cast(Cast { expr, field }) => Ok(self.cast_to_sql(expr, field)?), Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. 
}) => self.expr_to_sql_inner(expr), Expr::WindowFunction(window_fun) => { @@ -461,12 +460,12 @@ impl Unparser<'_> { ) }) } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }) } @@ -1146,24 +1145,27 @@ impl Unparser<'_> { // Explicit type cast on ast::Expr::Value is not needed by underlying engine for certain types // For example: CAST(Utf8("binary_value") AS Binary) and CAST(Utf8("dictionary_value") AS Dictionary) - fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) -> Result { + fn cast_to_sql(&self, expr: &Expr, field: &FieldRef) -> Result { let inner_expr = self.expr_to_sql_inner(expr)?; + let data_type = field.data_type(); match inner_expr { ast::Expr::Value(_) => match data_type { - DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView => { + DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView + if field.metadata().is_empty() => + { Ok(inner_expr) } _ => Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }), }, _ => Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }), } @@ -1689,7 +1691,8 @@ impl Unparser<'_> { })) } - fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { + fn arrow_dtype_to_ast_dtype(&self, field: &FieldRef) -> Result { + let data_type = field.data_type(); match data_type { DataType::Null => { not_impl_err!("Unsupported DataType: conversion: {data_type}") @@ -1762,7 +1765,9 @@ impl Unparser<'_> { 
DataType::Union(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), + DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype( + &Field::new("", val.as_ref().clone(), true).into(), + ), DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) | DataType::Decimal128(precision, scale) @@ -1903,34 +1908,25 @@ mod tests { r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Date64, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)), r#"CAST(a AS DATETIME)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+08:00".into()), - ), - }), + Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())), + )), r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - }), + Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Millisecond, None), + )), r#"CAST(a AS TIMESTAMP)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::UInt32, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::UInt32)), r#"CAST(a AS INTEGER UNSIGNED)"#, ), ( @@ -2248,10 +2244,7 @@ mod tests { r#"((a + b) > 100.123)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Decimal128(10, -2), - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Decimal128(10, -2))), r#"CAST(a AS DECIMAL(12,0))"#, ), ( @@ -2388,10 +2381,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Date64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)); let ast = unparser.expr_to_sql(&expr)?; 
let actual = format!("{ast}"); @@ -2413,10 +2403,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2646,23 +2633,23 @@ mod tests { fn test_cast_value_to_binary_expr() { let tests = [ ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::Binary, - }), + DataType::Binary, + )), "'blah'", ), ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::BinaryView, - }), + DataType::BinaryView, + )), "'blah'", ), ]; @@ -2693,10 +2680,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2779,10 +2763,7 @@ mod tests { [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Int64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2807,10 +2788,7 @@ mod tests { [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Int32, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2846,10 +2824,7 @@ mod 
tests { (&mysql_dialect, &timestamp_with_tz, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: data_type.clone(), - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type.clone())); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2902,10 +2877,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2945,13 +2917,13 @@ mod tests { #[test] fn test_cast_value_to_dict_expr() { let tests = [( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("variation".to_string())), None, )), - data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), - }), + DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), + )), "'variation'", )]; for (value, expected) in tests { @@ -2983,10 +2955,7 @@ mod tests { datafusion_functions::math::round::RoundFunc::new(), )), args: vec![ - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); @@ -3148,10 +3117,12 @@ mod tests { let unparser = Unparser::new(&dialect); - let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ))?; + let arrow_field = Arc::new(Field::new( + "", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )); + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?; assert_eq!(ast_dtype, ast::DataType::Varchar(None)); @@ -3165,7 +3136,8 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let ast_dtype = 
unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?; + let arrow_field = Arc::new(Field::new("", DataType::Utf8View, true)); + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?; assert_eq!(ast_dtype, ast::DataType::Char(None)); @@ -3233,10 +3205,10 @@ mod tests { let dialect: Arc = Arc::new(SqliteDialect {}); let unparser = Unparser::new(dialect.as_ref()); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Nanosecond, None), - }); + let expr = Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Nanosecond, None), + )); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs index ec70ac3fec340..3dd62afe8f193 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-use crate::logical_plan::consumer::SubstraitConsumer; -use crate::logical_plan::consumer::types::from_substrait_type_without_names; +use crate::logical_plan::consumer::{ + SubstraitConsumer, field_from_substrait_type_without_names, +}; use datafusion::common::{DFSchema, substrait_err}; use datafusion::logical_expr::{Cast, Expr, TryCast}; use substrait::proto::expression as substrait_expression; @@ -37,11 +38,11 @@ pub async fn from_cast( ) .await?, ); - let data_type = from_substrait_type_without_names(consumer, output_type)?; + let field = field_from_substrait_type_without_names(consumer, output_type)?; if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + Ok(Expr::TryCast(TryCast::new_from_field(input_expr, field))) } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) + Ok(Expr::Cast(Cast::new_from_field(input_expr, field))) } } None => substrait_err!("Cast expression without output type is not allowed"), diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index eb2cc967ca236..205fd78f0d86b 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -34,14 +34,22 @@ use crate::variation_const::{ }; use crate::variation_const::{FLOAT_16_TYPE_NAME, NULL_TYPE_NAME}; use datafusion::arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; +use datafusion::common::datatype::DataTypeExt; use datafusion::common::{ DFSchema, not_impl_err, substrait_datafusion_err, substrait_err, }; use std::sync::Arc; use substrait::proto::{NamedStruct, Type, r#type}; +pub(crate) fn field_from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, + dt: &Type, +) -> datafusion::common::Result { + Ok(from_substrait_type_without_names(consumer, dt)?.into_nullable_field_ref()) +} + 
pub(crate) fn from_substrait_type_without_names( consumer: &impl SubstraitConsumer, dt: &Type, @@ -49,6 +57,16 @@ pub(crate) fn from_substrait_type_without_names( from_substrait_type(consumer, dt, &[], &mut 0) } +pub fn field_from_substrait_type( + consumer: &impl SubstraitConsumer, + dt: &Type, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + // We could add nullability here now that we are returning a Field + Ok(from_substrait_type(consumer, dt, dfs_names, name_idx)?.into_nullable_field_ref()) +} + pub fn from_substrait_type( consumer: &impl SubstraitConsumer, dt: &Type, diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 53d3d3e12c4bf..e2140a9a2254f 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type}; +use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type_from_field}; use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; use datafusion::common::{DFSchemaRef, ScalarValue}; use datafusion::logical_expr::{Cast, Expr, TryCast}; @@ -29,7 +29,7 @@ pub fn from_cast( cast: &Cast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let Cast { expr, data_type } = cast; + let Cast { expr, field } = cast; // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null if let Expr::Literal(lit, _) = expr.as_ref() { // only the untyped(a null scalar value) null literal need this special handling @@ -39,8 +39,8 @@ pub fn from_cast( let lit = Literal { nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::Null(to_substrait_type( - producer, data_type, true, + literal_type: Some(LiteralType::Null(to_substrait_type_from_field( + producer, field, )?)), }; return Ok(Expression { @@ -51,7 +51,7 @@ pub fn from_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(producer, data_type, true)?), + r#type: Some(to_substrait_type_from_field(producer, field)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, @@ -64,11 +64,11 @@ pub fn from_try_cast( cast: &TryCast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let TryCast { expr, data_type } = cast; + let TryCast { expr, field } = cast; Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(producer, data_type, true)?), + r#type: Some(to_substrait_type_from_field(producer, field)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, @@ -80,7 +80,7 @@ pub fn 
from_try_cast( mod tests { use super::*; use crate::logical_plan::producer::{ - DefaultSubstraitProducer, to_substrait_extended_expr, + DefaultSubstraitProducer, to_substrait_extended_expr, to_substrait_type, }; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DFSchema; diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 3727596119bc3..fa58949e6ecd2 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -27,7 +27,7 @@ use crate::variation_const::{ TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; -use datafusion::arrow::datatypes::{DataType, IntervalUnit}; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; use datafusion::common::{DFSchemaRef, not_impl_err, plan_err}; use substrait::proto::{NamedStruct, r#type}; @@ -36,12 +36,19 @@ pub(crate) fn to_substrait_type( dt: &DataType, nullable: bool, ) -> datafusion::common::Result { - let nullability = if nullable { + to_substrait_type_from_field(producer, &Field::new("", dt.clone(), nullable).into()) +} + +pub(crate) fn to_substrait_type_from_field( + producer: &mut impl SubstraitProducer, + field: &FieldRef, +) -> datafusion::common::Result { + let nullability = if field.is_nullable() { r#type::Nullability::Nullable as i32 } else { r#type::Nullability::Required as i32 }; - match dt { + match field.data_type() { DataType::Null => { let type_anchor = producer.register_type(NULL_TYPE_NAME.to_string()); Ok(substrait::proto::Type { @@ -288,16 +295,9 @@ pub(crate) fn to_substrait_type( } DataType::Map(inner, _) => match inner.data_type() { DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - let key_type = to_substrait_type( - producer, - key_and_value[0].data_type(), - key_and_value[0].is_nullable(), - )?; - let 
value_type = to_substrait_type( - producer, - key_and_value[1].data_type(), - key_and_value[1].is_nullable(), - )?; + let key_type = to_substrait_type_from_field(producer, &key_and_value[0])?; + let value_type = + to_substrait_type_from_field(producer, &key_and_value[1])?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -310,8 +310,14 @@ pub(crate) fn to_substrait_type( _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), }, DataType::Dictionary(key_type, value_type) => { - let key_type = to_substrait_type(producer, key_type, nullable)?; - let value_type = to_substrait_type(producer, value_type, nullable)?; + let key_type = to_substrait_type_from_field( + producer, + &Field::new("", key_type.as_ref().clone(), field.is_nullable()).into(), + )?; + let value_type = to_substrait_type_from_field( + producer, + &Field::new("", value_type.as_ref().clone(), field.is_nullable()).into(), + )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -324,9 +330,7 @@ pub(crate) fn to_substrait_type( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| { - to_substrait_type(producer, field.data_type(), field.is_nullable()) - }) + .map(|field| to_substrait_type_from_field(producer, field)) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -352,7 +356,7 @@ pub(crate) fn to_substrait_type( precision: *p as i32, })), }), - _ => not_impl_err!("Unsupported cast type: {dt}"), + _ => not_impl_err!("Unsupported cast type: {field}"), } } @@ -369,7 +373,7 @@ pub(crate) fn to_substrait_named_struct( types: schema .fields() .iter() - .map(|f| to_substrait_type(producer, f.data_type(), f.is_nullable())) + .map(|f| to_substrait_type_from_field(producer, f)) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, 
diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 157e0339e1eff..88912f7a7b933 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -767,6 +767,21 @@ See the [default column values example](https://github.com/apache/datafusion/blo If you implemented a custom `SchemaAdapterFactory`, migrate to `PhysicalExprAdapterFactory`. See the [default column values example](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs) for a complete implementation. +### `Expr::ScalarVariable`, `Expr::Cast`, and `Expr::TryCast` now store type information as a `FieldRef` + +Code that explicitly constructed `Expr::ScalarVariable(DataType::xxx)` must now convert +the `DataType` into a `FieldRef` (e.g., `Expr::ScalarVariable(DataType::Int8.into_nullable_field_ref())`). +Implementations of a custom `ContextProvider` may implement `get_variable_field()` to provide +planning information for variables with Arrow extension types or explicit nullability. See +[#18243](https://github.com/apache/datafusion/pull/18243) for more information. + +Similarly, the `Cast` and `TryCast` types wrapped by `Expr::Cast` and `Expr::TryCast` were updated to use a `FieldRef` as the +underlying storage to represent casts to Arrow extension types. Code that constructed +or matched casts in the form `Expr::Cast(Cast { expr, data_type })` (or the `TryCast` equivalent) will need to be updated to use the new `field` member. +Casts to `DataType` can continue to be constructed with `Expr::Cast(Cast::new(expr, DataType::...))`, +although we recommend `Expr::Cast(Cast::new_from_field(expr, existing_field_ref))` when +constructing casts programmatically to avoid dropping extension type metadata. + ## DataFusion `51.0.0` ### `arrow` / `parquet` updated to 57.0.0