Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl TableProvider for CustomProvider {
Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
Expr::Cast(Cast { expr, field: _ }) => match expr.deref() {
Expr::Literal(lit_value, _) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,7 @@ impl ScalarUDFImpl for CastToI64UDF {
arg
} else {
// need to use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: DataType::Int64,
})
Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), DataType::Int64))
};
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
Expand Down
60 changes: 37 additions & 23 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
use datafusion_common::datatype::DataTypeExt;
use datafusion_common::metadata::format_type_and_metadata;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
Expand Down Expand Up @@ -796,13 +798,20 @@ pub struct Cast {
/// The expression being cast
pub expr: Box<Expr>,
/// The field describing the target of the cast; its `DataType` is the type the expression will yield
pub data_type: DataType,
pub field: FieldRef,
}

impl Cast {
/// Create a new Cast expression that casts `expr` to `data_type`.
///
/// The target is stored as the field produced by
/// `data_type.into_nullable_field_ref()`.
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
Self { expr, data_type }
Self {
expr,
field: data_type.into_nullable_field_ref(),
}
}

/// Create a new Cast expression whose target is described by `field`.
///
/// The given `field` is stored as-is.
pub fn new_from_field(expr: Box<Expr>, field: FieldRef) -> Self {
Self { expr, field }
}
}

Expand All @@ -812,13 +821,20 @@ pub struct TryCast {
/// The expression being cast
pub expr: Box<Expr>,
/// The field describing the target of the cast; its `DataType` is the type the expression will yield
pub data_type: DataType,
pub field: FieldRef,
Comment on lines -815 to +824
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we keep data_type as a deprecated field that we populate from field.data_type() for a couple of releases?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also implement this by adding a metadata: FieldMetadata field rather than switching the data_type to a FieldRef. We mostly just switched DataType to FieldRef in other places so that's what I did here.

}

impl TryCast {
/// Create a new TryCast expression that casts `expr` to `data_type`.
///
/// The target is stored as the field produced by
/// `data_type.into_nullable_field_ref()`.
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
Self { expr, data_type }
Self {
expr,
field: data_type.into_nullable_field_ref(),
}
}

/// Create a new TryCast expression whose target is described by `field`.
///
/// The given `field` is stored as-is.
pub fn new_from_field(expr: Box<Expr>, field: FieldRef) -> Self {
Self { expr, field }
}
}

Expand Down Expand Up @@ -2252,23 +2268,23 @@ impl NormalizeEq for Expr {
(
Expr::Cast(Cast {
expr: self_expr,
data_type: self_data_type,
field: self_field,
}),
Expr::Cast(Cast {
expr: other_expr,
data_type: other_data_type,
field: other_field,
}),
)
| (
Expr::TryCast(TryCast {
expr: self_expr,
data_type: self_data_type,
field: self_field,
}),
Expr::TryCast(TryCast {
expr: other_expr,
data_type: other_data_type,
field: other_field,
}),
) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr),
) => self_field == other_field && self_expr.normalize_eq(other_expr),
(
Expr::ScalarFunction(ScalarFunction {
func: self_func,
Expand Down Expand Up @@ -2584,15 +2600,9 @@ impl HashNode for Expr {
when_then_expr: _when_then_expr,
else_expr: _else_expr,
}) => {}
Expr::Cast(Cast {
expr: _expr,
data_type,
})
| Expr::TryCast(TryCast {
expr: _expr,
data_type,
}) => {
data_type.hash(state);
Expr::Cast(Cast { expr: _expr, field })
| Expr::TryCast(TryCast { expr: _expr, field }) => {
field.hash(state);
}
Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
func.hash(state);
Expand Down Expand Up @@ -3282,11 +3292,15 @@ impl Display for Expr {
}
write!(f, "END")
}
Expr::Cast(Cast { expr, data_type }) => {
write!(f, "CAST({expr} AS {data_type})")
Expr::Cast(Cast { expr, field }) => {
let formatted =
format_type_and_metadata(field.data_type(), Some(field.metadata()));
write!(f, "CAST({expr} AS {formatted})")
}
Expr::TryCast(TryCast { expr, data_type }) => {
write!(f, "TRY_CAST({expr} AS {data_type})")
Expr::TryCast(TryCast { expr, field }) => {
let formatted =
format_type_and_metadata(field.data_type(), Some(field.metadata()));
write!(f, "TRY_CAST({expr} AS {formatted})")
}
Expr::Not(expr) => write!(f, "NOT {expr}"),
Expr::Negative(expr) => write!(f, "(- {expr})"),
Expand Down Expand Up @@ -3672,7 +3686,7 @@ mod test {
fn format_cast() -> Result<()> {
let expr = Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)),
data_type: DataType::Utf8,
field: DataType::Utf8.into_nullable_field_ref(),
});
let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
assert_eq!(expected_canonical, format!("{expr}"));
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection(

if let Some(found) = found {
return Ok(Transformed::yes(match normalized_expr {
Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
expr: Box::new(found),
data_type,
field,
}),
Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast {
Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast {
expr: Box::new(found),
data_type,
field,
}),
_ => found,
}));
Expand Down
23 changes: 19 additions & 4 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ impl ExprSchemable for Expr {
.as_ref()
.map_or(Ok(DataType::Null), |e| e.get_type(schema))
}
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, .. }) => {
Ok(field.data_type().clone())
}
Expr::Unnest(Unnest { expr }) => {
let arg_data_type = expr.get_type(schema)?;
// Unnest's output type is the inner type of the list
Expand Down Expand Up @@ -611,9 +612,23 @@ impl ExprSchemable for Expr {
func.return_field_from_args(args)
}
// _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
Expr::Cast(Cast { expr, data_type }) => expr
Expr::Cast(Cast { expr, field }) => expr
.to_field(schema)
.map(|(_, f)| f.retyped(data_type.clone())),
.map(|(_table_ref, destination_field)| {
// This propagates the nullability of the input rather than
// force the nullability of the destination field. This is
// usually the desired behaviour (i.e., specifying a cast
// destination type usually does not force a user to pick
// nullability, and assuming `true` would prevent the non-nullability
// of the parent expression to make the result eligible for
// optimizations that only apply to non-nullable values).
//
// NOTE(review): despite its name, `destination_field` is the field
// returned by `expr.to_field(schema)`, i.e. the field of the
// expression being cast. The `with_metadata` call below therefore
// re-applies that field's own metadata, and the metadata of the cast
// target `field` is not used here -- confirm whether
// `field.metadata()` was intended.
destination_field
.as_ref()
.clone()
.with_data_type(field.data_type().clone())
.with_metadata(destination_field.metadata().clone())
})
.map(Arc::new),
Expr::Placeholder(Placeholder {
id: _,
field: Some(field),
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ impl TreeNode for Expr {
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
}),
Expr::Cast(Cast { expr, data_type }) => expr
Expr::Cast(Cast { expr, field }) => expr
.map_elements(f)?
.update_data(|be| Expr::Cast(Cast::new(be, data_type))),
Expr::TryCast(TryCast { expr, data_type }) => expr
.update_data(|be| Expr::Cast(Cast::new_from_field(be, field))),
Expr::TryCast(TryCast { expr, field }) => expr
.map_elements(f)?
.update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
.update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))),
Expr::ScalarFunction(ScalarFunction { func, args }) => {
args.map_elements(f)?.map_data(|new_args| {
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
Expand Down
8 changes: 4 additions & 4 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::types::logical_string;
use datafusion_common::{
Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err,
Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt,
exec_datafusion_err, exec_err, internal_err, types::logical_string,
utils::take_function_args,
};
use datafusion_common::{exec_datafusion_err, utils::take_function_args};
use std::any::Any;

use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
Expand Down Expand Up @@ -176,7 +176,7 @@ impl ScalarUDFImpl for ArrowCastFunc {
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: target_type,
field: target_type.into_nullable_field_ref(),
})
};
// return the newly written argument to DataFusion
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/eliminate_outer_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ fn extract_non_nullable_columns(
false,
)
}
Expr::Cast(Cast { expr, data_type: _ })
| Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns(
Expr::Cast(Cast { expr, field: _ })
| Expr::TryCast(TryCast { expr, field: _ }) => extract_non_nullable_columns(
expr,
non_nullable_cols,
left_schema,
Expand Down
89 changes: 76 additions & 13 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{

use arrow::datatypes::Schema;
use datafusion_common::config::ConfigOptions;
use datafusion_common::metadata::FieldMetadata;
use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata};
use datafusion_common::{
DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, plan_err,
};
Expand All @@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction};
use datafusion_expr::var_provider::VarType;
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{
Between, BinaryExpr, Expr, Like, Operator, TryCast, binary_expr, lit,
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, binary_expr, lit,
};

/// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1
Expand Down Expand Up @@ -288,16 +288,44 @@ pub fn create_physical_expr(
};
Ok(expressions::case(expr, when_then_expr, else_expr)?)
}
Expr::Cast(Cast { expr, data_type }) => expressions::cast(
create_physical_expr(expr, input_dfschema, execution_props)?,
input_schema,
data_type.clone(),
),
Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast(
create_physical_expr(expr, input_dfschema, execution_props)?,
input_schema,
data_type.clone(),
),
Expr::Cast(Cast { expr, field }) => {
if !field.metadata().is_empty() {
let (_, src_field) = expr.to_field(input_dfschema)?;
return plan_err!(
"Cast from {} to {} is not supported",
format_type_and_metadata(
src_field.data_type(),
Some(src_field.metadata()),
),
format_type_and_metadata(field.data_type(), Some(field.metadata()))
);
}

expressions::cast(
create_physical_expr(expr, input_dfschema, execution_props)?,
input_schema,
field.data_type().clone(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does end up tying into #19097: I think they'd work well together, we'd just want to pass the field directly here.

)
}
Expr::TryCast(TryCast { expr, field }) => {
if !field.metadata().is_empty() {
let (_, src_field) = expr.to_field(input_dfschema)?;
return plan_err!(
"TryCast from {} to {} is not supported",
format_type_and_metadata(
src_field.data_type(),
Some(src_field.metadata()),
),
format_type_and_metadata(field.data_type(), Some(field.metadata()))
);
}

expressions::try_cast(
create_physical_expr(expr, input_dfschema, execution_props)?,
input_schema,
field.data_type().clone(),
)
}
Expr::Not(expr) => {
expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?)
}
Expand Down Expand Up @@ -417,7 +445,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
mod tests {
use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field};

use datafusion_common::datatype::DataTypeExt;
use datafusion_expr::{Operator, col, lit};

use super::*;
Expand Down Expand Up @@ -447,6 +475,41 @@ mod tests {
Ok(())
}

#[test]
fn test_cast_to_extension_type() -> Result<()> {
    // Target field: FixedSizeBinary(16) whose metadata carries the
    // `arrow.uuid` extension name.
    let uuid_metadata =
        [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())];
    let uuid_field = Arc::new(
        DataType::FixedSizeBinary(16)
            .into_nullable_field()
            .with_metadata(uuid_metadata.into()),
    );
    let uuid_literal = lit("3230e5d4-888e-408b-b09b-831f44aa0c58");

    // Build the CAST and TRY_CAST forms of the same conversion.
    let candidates = [
        Expr::Cast(Cast::new_from_field(
            Box::new(uuid_literal.clone()),
            Arc::clone(&uuid_field),
        )),
        Expr::TryCast(TryCast::new_from_field(
            Box::new(uuid_literal.clone()),
            Arc::clone(&uuid_field),
        )),
    ];

    // Planning either form must fail, and the error must mention the
    // extension type name.
    for logical_expr in &candidates {
        let planning_error = create_physical_expr(
            logical_expr,
            &DFSchema::empty(),
            &ExecutionProps::new(),
        )
        .unwrap_err();
        assert!(planning_error.message().contains("arrow.uuid"));
    }

    Ok(())
}

/// Test that deeply nested expressions do not cause a stack overflow.
///
/// This test only runs when the `recursive_protection` feature is enabled,
Expand Down
Loading