Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ pub enum Expr {
/// A named reference to a qualified field in a schema.
Column(Column),
/// A named reference to a variable in a registry.
ScalarVariable(DataType, Vec<String>),
ScalarVariable(FieldRef, Vec<String>),
/// A constant value along with associated [`FieldMetadata`].
Literal(ScalarValue, Option<FieldMetadata>),
/// A binary expression such as "age > 21"
Expand Down Expand Up @@ -2529,8 +2529,8 @@ impl HashNode for Expr {
Expr::Column(column) => {
column.hash(state);
}
Expr::ScalarVariable(data_type, name) => {
data_type.hash(state);
Expr::ScalarVariable(field, name) => {
field.hash(state);
name.hash(state);
}
Expr::Literal(scalar_value, _) => {
Expand Down
31 changes: 21 additions & 10 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl ExprSchemable for Expr {
Expr::Negative(expr) => expr.get_type(schema),
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()),
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()),
Expr::Literal(l, _) => Ok(l.data_type()),
Expr::Case(case) => {
for (_, then_expr) in &case.when_then_expr {
Expand Down Expand Up @@ -365,12 +365,8 @@ impl ExprSchemable for Expr {
window_function,
)
.map(|(_, nullable)| nullable),
Expr::Placeholder(Placeholder { id: _, field }) => {
Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true))
}
Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => {
Ok(true)
}
Expr::ScalarVariable(field, _) => Ok(field.is_nullable()),
Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
| Expr::IsNotNull(_)
| Expr::IsTrue(_)
Expand Down Expand Up @@ -503,9 +499,7 @@ impl ExprSchemable for Expr {
Expr::OuterReferenceColumn(field, _) => {
Ok(Arc::clone(field).renamed(&schema_name))
}
Expr::ScalarVariable(ty, _) => {
Ok(Arc::new(Field::new(&schema_name, ty.clone(), true)))
}
Expr::ScalarVariable(field, _) => Ok(Arc::clone(field).renamed(&schema_name)),
Expr::Literal(l, metadata) => Ok(Arc::new(
Field::new(&schema_name, l.data_type(), l.is_null())
.with_field_metadata_opt(metadata.as_ref()),
Expand Down Expand Up @@ -1206,4 +1200,21 @@ mod tests {
Ok(&self.field)
}
}

#[test]
fn test_scalar_variable() {
let mut meta = HashMap::new();
meta.insert("bar".to_string(), "buzz".to_string());
let meta = FieldMetadata::from(meta);

let field = Field::new("foo", DataType::Int32, true);
let field = meta.add_to_field(field);
let field = Arc::new(field);

let expr = Expr::ScalarVariable(field, vec!["foo".to_string()]);

let schema = MockExprSchema::new();

assert_eq!(meta, expr.metadata(&schema).unwrap());
}
}
14 changes: 13 additions & 1 deletion datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use crate::{
AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame,
WindowFunctionDefinition, WindowUDF,
};
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
use datafusion_common::datatype::DataTypeExt;
use datafusion_common::{
DFSchema, Result, TableReference, config::ConfigOptions,
file_options::file_type::FileType, not_impl_err,
Expand Down Expand Up @@ -103,6 +104,17 @@ pub trait ContextProvider {
/// A user defined variable is typically accessed via `@var_name`
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>;

/// Return metadata about a system/user-defined variable, if any.
///
/// By default, this wraps [`Self::get_variable_type`] in an Arrow [`Field`]
/// with nullable set to `true` and no metadata. Implementations that can
/// provide richer information (such as nullability or extension metadata)
/// should override this method.
fn get_variable_field(&self, variable_names: &[String]) -> Option<FieldRef> {
self.get_variable_type(variable_names)
.map(|data_type| data_type.into_nullable_field_ref())
}

/// Return overall configuration options
fn options(&self) -> &ConfigOptions;

Expand Down
23 changes: 17 additions & 6 deletions datafusion/sql/src/expr/identifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use arrow::datatypes::FieldRef;
use datafusion_common::datatype::DataTypeExt;
use datafusion_common::{
assert_or_internal_err, exec_datafusion_err, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference,
Expand All @@ -39,13 +40,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
if id.value.starts_with('@') {
// TODO: figure out if ScalarVariables should be insensitive.
let var_names = vec![id.value];
let ty = self
let field = self
.context_provider
.get_variable_type(&var_names)
.get_variable_field(&var_names)
.or_else(|| {
self.context_provider
.get_variable_type(&var_names)
.map(|ty| ty.into_nullable_field_ref())
})
.ok_or_else(|| {
plan_datafusion_err!("variable {var_names:?} has no type information")
})?;
Ok(Expr::ScalarVariable(ty, var_names))
Ok(Expr::ScalarVariable(field, var_names))
} else {
// Don't use `col()` here because it will try to
// interpret names with '.' as if they were
Expand Down Expand Up @@ -111,13 +117,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.into_iter()
.map(|id| self.ident_normalizer.normalize(id))
.collect();
let ty = self
let field = self
.context_provider
.get_variable_type(&var_names)
.get_variable_field(&var_names)
.or_else(|| {
self.context_provider
.get_variable_type(&var_names)
.map(|ty| ty.into_nullable_field_ref())
})
.ok_or_else(|| {
exec_datafusion_err!("variable {var_names:?} has no type information")
})?;
Ok(Expr::ScalarVariable(ty, var_names))
Ok(Expr::ScalarVariable(field, var_names))
} else {
let ids = ids
.into_iter()
Expand Down
8 changes: 6 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,7 @@ mod tests {
use arrow::array::{LargeListArray, ListArray};
use arrow::datatypes::{DataType::Int8, Field, Int32Type, Schema, TimeUnit};
use ast::ObjectName;
use datafusion_common::datatype::DataTypeExt;
use datafusion_common::{Spans, TableReference};
use datafusion_expr::expr::WildcardOptions;
use datafusion_expr::{
Expand Down Expand Up @@ -2169,12 +2170,15 @@ mod tests {
r#"TRY_CAST(a AS INTEGER UNSIGNED)"#,
),
(
Expr::ScalarVariable(Int8, vec![String::from("@a")]),
Expr::ScalarVariable(
Int8.into_nullable_field_ref(),
vec![String::from("@a")],
),
r#"@a"#,
),
(
Expr::ScalarVariable(
Int8,
Int8.into_nullable_field_ref(),
vec![String::from("@root"), String::from("foo")],
),
r#"@root.foo"#,
Expand Down