diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a595b59355739..088880e28b8f4 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -316,7 +316,7 @@ pub enum Expr { /// A named reference to a qualified field in a schema. Column(Column), /// A named reference to a variable in a registry. - ScalarVariable(DataType, Vec), + ScalarVariable(FieldRef, Vec), /// A constant value along with associated [`FieldMetadata`]. Literal(ScalarValue, Option), /// A binary expression such as "age > 21" @@ -2529,8 +2529,8 @@ impl HashNode for Expr { Expr::Column(column) => { column.hash(state); } - Expr::ScalarVariable(data_type, name) => { - data_type.hash(state); + Expr::ScalarVariable(field, name) => { + field.hash(state); name.hash(state); } Expr::Literal(scalar_value, _) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 0d895310655ca..691a8c508f801 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -121,7 +121,7 @@ impl ExprSchemable for Expr { Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()), - Expr::ScalarVariable(ty, _) => Ok(ty.clone()), + Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()), Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { @@ -365,12 +365,8 @@ impl ExprSchemable for Expr { window_function, ) .map(|(_, nullable)| nullable), - Expr::Placeholder(Placeholder { id: _, field }) => { - Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true)) - } - Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => { - Ok(true) - } + Expr::ScalarVariable(field, _) => Ok(field.is_nullable()), + Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -503,9 +499,7 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => { Ok(Arc::clone(field).renamed(&schema_name)) } - Expr::ScalarVariable(ty, _) => { - Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) - } + Expr::ScalarVariable(field, _) => Ok(Arc::clone(field).renamed(&schema_name)), Expr::Literal(l, metadata) => Ok(Arc::new( Field::new(&schema_name, l.data_type(), l.is_null()) .with_field_metadata_opt(metadata.as_ref()), @@ -1206,4 +1200,21 @@ mod tests { Ok(&self.field) } } + + #[test] + fn test_scalar_variable() { + let mut meta = HashMap::new(); + meta.insert("bar".to_string(), "buzz".to_string()); + let meta = FieldMetadata::from(meta); + + let field = Field::new("foo", DataType::Int32, true); + let field = meta.add_to_field(field); + let field = Arc::new(field); + + let expr = Expr::ScalarVariable(field, vec!["foo".to_string()]); + + let schema = MockExprSchema::new(); + + assert_eq!(meta, expr.metadata(&schema).unwrap()); + } } diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 794c394d11d49..0c5dfa2ee1c97 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -25,7 +25,8 @@ use crate::{ AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF, }; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ DFSchema, Result, TableReference, config::ConfigOptions, file_options::file_type::FileType, not_impl_err, @@ -103,6 +104,17 @@ pub trait ContextProvider { /// A user defined variable is typically accessed via `@var_name` fn get_variable_type(&self, variable_names: &[String]) -> Option; + /// Return metadata about a system/user-defined variable, if any. + /// + /// By default, this wraps [`Self::get_variable_type`] in an Arrow [`Field`] + /// with nullable set to `true` and no metadata. Implementations that can + /// provide richer information (such as nullability or extension metadata) + /// should override this method. + fn get_variable_field(&self, variable_names: &[String]) -> Option { + self.get_variable_type(variable_names) + .map(|data_type| data_type.into_nullable_field_ref()) + } + /// Return overall configuration options fn options(&self) -> &ConfigOptions; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index b329cc5d1fe9b..79f65e3f59b32 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::FieldRef; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ assert_or_internal_err, exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, @@ -39,13 +40,18 @@ impl SqlToRel<'_, S> { if id.value.starts_with('@') { // TODO: figure out if ScalarVariables should be insensitive. let var_names = vec![id.value]; - let ty = self + let field = self .context_provider - .get_variable_type(&var_names) + .get_variable_field(&var_names) + .or_else(|| { + self.context_provider + .get_variable_type(&var_names) + .map(|ty| ty.into_nullable_field_ref()) + }) .ok_or_else(|| { plan_datafusion_err!("variable {var_names:?} has no type information") })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::ScalarVariable(field, var_names)) } else { // Don't use `col()` here because it will try to // interpret names with '.' as if they were @@ -111,13 +117,18 @@ impl SqlToRel<'_, S> { .into_iter() .map(|id| self.ident_normalizer.normalize(id)) .collect(); - let ty = self + let field = self .context_provider - .get_variable_type(&var_names) + .get_variable_field(&var_names) + .or_else(|| { + self.context_provider + .get_variable_type(&var_names) + .map(|ty| ty.into_nullable_field_ref()) + }) .ok_or_else(|| { exec_datafusion_err!("variable {var_names:?} has no type information") })?; - Ok(Expr::ScalarVariable(ty, var_names)) + Ok(Expr::ScalarVariable(field, var_names)) } else { let ids = ids .into_iter() diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 575cfd27ee354..62e1927ccfa14 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1783,6 +1783,7 @@ mod tests { use arrow::array::{LargeListArray, ListArray}; use arrow::datatypes::{DataType::Int8, Field, Int32Type, Schema, TimeUnit}; use ast::ObjectName; + use datafusion_common::datatype::DataTypeExt; use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ @@ -2169,12 +2170,15 @@ mod tests { r#"TRY_CAST(a AS INTEGER UNSIGNED)"#, ), ( - Expr::ScalarVariable(Int8, vec![String::from("@a")]), + Expr::ScalarVariable( + Int8.into_nullable_field_ref(), + vec![String::from("@a")], + ), r#"@a"#, ), ( Expr::ScalarVariable( - Int8, + Int8.into_nullable_field_ref(), vec![String::from("@root"), String::from("foo")], ), r#"@root.foo"#,