From c26639d5ba033acdbc27d372d87b1b2b5dd3bbbe Mon Sep 17 00:00:00 2001 From: kould Date: Thu, 21 Aug 2025 14:23:07 +0800 Subject: [PATCH 1/2] feat: implement Scalar UDF --- src/meta/app/src/principal/mod.rs | 1 + .../src/principal/user_defined_function.rs | 30 ++++++++ .../src/udf_from_to_protobuf_impl.rs | 75 +++++++++++++++++++ src/meta/proto-conv/src/util.rs | 1 + src/meta/proto-conv/tests/it/main.rs | 1 + .../proto-conv/tests/it/v144_scalar_udf.rs | 64 ++++++++++++++++ src/meta/protos/proto/udf.proto | 14 ++++ src/query/ast/src/ast/statements/udf.rs | 17 +++++ src/query/ast/src/parser/statement.rs | 46 ++++++++++-- .../bind_table_function.rs | 58 +++----------- src/query/sql/src/planner/binder/udf.rs | 27 +++++++ src/query/sql/src/planner/semantic/mod.rs | 1 + .../sql/src/planner/semantic/type_check.rs | 56 ++++++++++++++ .../sql/src/planner/semantic/udf_rewriter.rs | 41 ++++++++++ .../system/src/user_functions_table.rs | 12 +++ .../base/03_common/03_0013_select_udf.test | 60 +++++++++++++++ 16 files changed, 451 insertions(+), 53 deletions(-) create mode 100644 src/meta/proto-conv/tests/it/v144_scalar_udf.rs diff --git a/src/meta/app/src/principal/mod.rs b/src/meta/app/src/principal/mod.rs index 8bdd1a9b4a876..08a2cef1a58dc 100644 --- a/src/meta/app/src/principal/mod.rs +++ b/src/meta/app/src/principal/mod.rs @@ -109,6 +109,7 @@ pub use user_auth::AuthType; pub use user_auth::PasswordHashMethod; pub use user_defined_file_format::UserDefinedFileFormat; pub use user_defined_function::LambdaUDF; +pub use user_defined_function::ScalarUDF; pub use user_defined_function::UDAFScript; pub use user_defined_function::UDFDefinition; pub use user_defined_function::UDFScript; diff --git a/src/meta/app/src/principal/user_defined_function.rs b/src/meta/app/src/principal/user_defined_function.rs index f865858ea7e86..f6d1210dba115 100644 --- a/src/meta/app/src/principal/user_defined_function.rs +++ b/src/meta/app/src/principal/user_defined_function.rs @@ -64,6 +64,19 @@ pub struct UDTF { pub sql: String, } +/// User Defined Scalar Function (ScalarUDF) +/// +/// # Fields +/// - `arg_types`: arg name with data type +/// - `return_type`: return data type +/// - `definition`: typically including the code or expression implementing the function logic +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ScalarUDF { + pub arg_types: Vec<(String, DataType)>, + pub return_type: DataType, + pub definition: String, +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct UDAFScript { pub code: String, @@ -85,6 +98,7 @@ pub enum UDFDefinition { UDFScript(UDFScript), UDAFScript(UDAFScript), UDTF(UDTF), + ScalarUDF(ScalarUDF), } impl UDFDefinition { @@ -95,6 +109,7 @@ impl UDFDefinition { Self::UDFScript(_) => "UDFScript", Self::UDAFScript(_) => "UDAFScript", Self::UDTF(_) => "UDTF", + UDFDefinition::ScalarUDF(_) => "ScalarUDF", } } @@ -104,6 +119,7 @@ impl UDFDefinition { Self::UDFServer(_) => false, Self::UDFScript(_) => false, Self::UDTF(_) => false, + Self::ScalarUDF(_) => false, Self::UDAFScript(_) => true, } } @@ -112,6 +128,7 @@ impl UDFDefinition { match self { Self::LambdaUDF(_) => "SQL", Self::UDTF(_) => "SQL", + Self::ScalarUDF(_) => "SQL", Self::UDFServer(x) => x.language.as_str(), Self::UDFScript(x) => x.language.as_str(), Self::UDAFScript(x) => x.language.as_str(), @@ -329,6 +346,19 @@ impl Display for UDFDefinition { } write!(f, ") AS $${sql}$$")?; } + UDFDefinition::ScalarUDF(ScalarUDF { + arg_types, + return_type, + definition, + }) => { + for (i, (name, ty)) in arg_types.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{name} {ty}")?; + } + write!(f, ") RETURNS {return_type} AS $${definition}$$")?; + } } Ok(()) } diff --git a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs index 58c41c9ed428e..fd850e5f4ff34 100644 --- a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs +++ b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs @@ -357,6 +357,75 @@ impl FromToProto for mt::UDTF { } } +impl FromToProto for mt::ScalarUDF { + type PB = pb::ScalarUdf; + + fn get_pb_ver(p: &Self::PB) -> u64 { + p.ver + } + + fn from_pb(p: Self::PB) -> Result + where Self: Sized { + reader_check_msg(p.ver, p.min_reader_ver)?; + + let mut arg_types = Vec::new(); + for arg_ty in p.arg_types { + let ty_pb = arg_ty.ty.ok_or_else(|| { + Incompatible::new("ScalarUDF.arg_types.ty can not be None".to_string()) + })?; + let ty = TableDataType::from_pb(ty_pb)?; + + arg_types.push((arg_ty.name, (&ty).into())); + } + + let return_type_pb = p.return_type.ok_or_else(|| { + Incompatible::new("ScalarUDF.return_type can not be None".to_string()) + })?; + let return_type = TableDataType::from_pb(return_type_pb)?; + + Ok(Self { + arg_types, + return_type: (&return_type).into(), + definition: p.definition, + }) + } + + fn to_pb(&self) -> Result { + let mut arg_types = Vec::with_capacity(self.arg_types.len()); + for (arg_name, arg_type) in self.arg_types.iter() { + let arg_type = infer_schema_type(arg_type) + .map_err(|e| { + Incompatible::new(format!( + "Convert DataType to TableDataType failed: {}", + e.message() + )) + })? + .to_pb()?; + arg_types.push(UdtfArg { + name: arg_name.clone(), + ty: Some(arg_type), + }); + } + + let return_type = infer_schema_type(&self.return_type) + .map_err(|e| { + Incompatible::new(format!( + "Convert DataType to TableDataType failed: {}", + e.message() + )) + })? + .to_pb()?; + + Ok(pb::ScalarUdf { + ver: VER, + min_reader_ver: MIN_READER_VER, + arg_types, + return_type: Some(return_type), + definition: self.definition.clone(), + }) + } +} + impl FromToProto for mt::UserDefinedFunction { type PB = pb::UserDefinedFunction; fn get_pb_ver(p: &Self::PB) -> u64 { @@ -380,6 +449,9 @@ impl FromToProto for mt::UserDefinedFunction { Some(pb::user_defined_function::Definition::Udtf(udtf)) => { mt::UDFDefinition::UDTF(mt::UDTF::from_pb(udtf)?) } + Some(pb::user_defined_function::Definition::ScalarUdf(scalar_udf)) => { + mt::UDFDefinition::ScalarUDF(mt::ScalarUDF::from_pb(scalar_udf)?) + } None => { return Err(Incompatible::new( "UserDefinedFunction.definition cannot be None".to_string(), @@ -415,6 +487,9 @@ impl FromToProto for mt::UserDefinedFunction { mt::UDFDefinition::UDTF(udtf) => { pb::user_defined_function::Definition::Udtf(udtf.to_pb()?) } + mt::UDFDefinition::ScalarUDF(scalar_udf) => { + pb::user_defined_function::Definition::ScalarUdf(scalar_udf.to_pb()?) + } }; Ok(pb::UserDefinedFunction { diff --git a/src/meta/proto-conv/src/util.rs b/src/meta/proto-conv/src/util.rs index ed95f902e2099..df34e479ca181 100644 --- a/src/meta/proto-conv/src/util.rs +++ b/src/meta/proto-conv/src/util.rs @@ -173,6 +173,7 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[ (141, "2025-08-06: Add: row_access.proto"), (142, "2025-08-15: Add: table_meta add row_access_policy"), (143, "2025-08-18: Add: add UDTF"), + (144, "2025-08-18: Add: add ScalarUDF"), // Dear developer: // If you're gonna add a new metadata version, you'll have to add a test for it. // You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`) diff --git a/src/meta/proto-conv/tests/it/main.rs b/src/meta/proto-conv/tests/it/main.rs index 4fc568f4977a4..810a169a20144 100644 --- a/src/meta/proto-conv/tests/it/main.rs +++ b/src/meta/proto-conv/tests/it/main.rs @@ -135,3 +135,4 @@ mod v140_task_message; mod v141_row_access_policy; mod v142_table_row_access_policy; mod v143_udtf; +mod v144_scalar_udf; diff --git a/src/meta/proto-conv/tests/it/v144_scalar_udf.rs b/src/meta/proto-conv/tests/it/v144_scalar_udf.rs new file mode 100644 index 0000000000000..9f743e343a042 --- /dev/null +++ b/src/meta/proto-conv/tests/it/v144_scalar_udf.rs @@ -0,0 +1,64 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::DateTime; +use chrono::Utc; +use databend_common_expression::types::DataType; +use databend_common_meta_app::principal::ScalarUDF; +use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UserDefinedFunction; +use fastrace::func_name; + +use crate::common; + +// These bytes are built when a new version in introduced, +// and are kept for backward compatibility test. +// +// ************************************************************* +// * These messages should never be updated, * +// * only be added when a new version is added, * +// * or be removed when an old version is no longer supported. * +// ************************************************************* +// +// The message bytes are built from the output of `test_pb_from_to()` +#[test] +fn test_decode_v144_scalar_udf() -> anyhow::Result<()> { + let bytes = vec![ + 10, 15, 116, 101, 115, 116, 95, 115, 99, 97, 108, 97, 114, 95, 117, 100, 102, 18, 21, 84, + 104, 105, 115, 32, 105, 115, 32, 97, 32, 100, 101, 115, 99, 114, 105, 112, 116, 105, 111, + 110, 74, 69, 10, 16, 10, 2, 99, 49, 18, 10, 146, 2, 0, 160, 6, 144, 1, 168, 6, 24, 10, 16, + 10, 2, 99, 50, 18, 10, 138, 2, 0, 160, 6, 144, 1, 168, 6, 24, 18, 10, 170, 2, 0, 160, 6, + 144, 1, 168, 6, 24, 26, 12, 67, 85, 82, 82, 69, 78, 84, 95, 68, 65, 84, 69, 160, 6, 144, 1, + 168, 6, 24, 42, 23, 50, 48, 50, 51, 45, 49, 50, 45, 49, 53, 32, 48, 49, 58, 50, 54, 58, 48, + 57, 32, 85, 84, 67, 160, 6, 144, 1, 168, 6, 24, + ]; + + let want = || UserDefinedFunction { + name: "test_scalar_udf".to_string(), + description: "This is a description".to_string(), + definition: UDFDefinition::ScalarUDF(ScalarUDF { + arg_types: vec![(s("c1"), DataType::String), (s("c2"), DataType::Boolean)], + return_type: DataType::Date, + definition: "CURRENT_DATE".to_string(), + }), + created_on: DateTime::::from_timestamp(1702603569, 0).unwrap(), + }; + + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), bytes.as_slice(), 144, want()) +} + +fn s(ss: impl ToString) -> String { + ss.to_string() +} diff --git a/src/meta/protos/proto/udf.proto b/src/meta/protos/proto/udf.proto index a3d59bf710585..7ae97266617d7 100644 --- a/src/meta/protos/proto/udf.proto +++ b/src/meta/protos/proto/udf.proto @@ -85,6 +85,19 @@ message UDTF { string sql = 3; } +message ScalarUDF { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + // arg name with data type + repeated UDTFArg arg_types = 1; + // return data type + DataType return_type = 2; + // typically including the code or expression implementing the function logic + string definition = 3; +} + + message UserDefinedFunction { uint64 ver = 100; uint64 min_reader_ver = 101; @@ -97,6 +110,7 @@ message UserDefinedFunction { UDFScript udf_script = 6; UDAFScript udaf_script = 7; UDTF udtf = 8; + ScalarUDF scalar_udf = 9; } // The time udf created. optional string created_on = 5; diff --git a/src/query/ast/src/ast/statements/udf.rs b/src/query/ast/src/ast/statements/udf.rs index 0d0f6c4181cd3..9215594afc227 100644 --- a/src/query/ast/src/ast/statements/udf.rs +++ b/src/query/ast/src/ast/statements/udf.rs @@ -76,6 +76,11 @@ pub enum UDFDefinition { return_types: Vec<(Identifier, TypeName)>, sql: String, }, + ScalarUDF { + arg_types: Vec<(Identifier, TypeName)>, + definition: String, + return_type: TypeName, + }, } impl Display for UDFDefinition { @@ -198,6 +203,18 @@ impl Display for UDFDefinition { )?; write!(f, ") AS $$\n{sql}\n$$")?; } + UDFDefinition::ScalarUDF { + arg_types, + definition, + return_type, + } => { + write!(f, "(")?; + write_comma_separated_list( + f, + arg_types.iter().map(|(name, ty)| format!("{name} {ty}")), + )?; + write!(f, ") RETURNS {return_type} AS $$\n{definition}\n$$")?; + } UDFDefinition::UDAFScript { arg_types, state_fields: state_types, diff --git a/src/query/ast/src/parser/statement.rs b/src/query/ast/src/parser/statement.rs index 5c70783b8fb51..5fbab34f63047 100644 --- a/src/query/ast/src/parser/statement.rs +++ b/src/query/ast/src/parser/statement.rs @@ -4975,6 +4975,31 @@ pub fn udf_script_or_address(i: Input) -> IResult<(String, bool)> { } pub fn udf_definition(i: Input) -> IResult { + enum ReturnBody { + Scalar(TypeName), + Table(Vec<(Identifier, TypeName)>), + } + + fn return_body(i: Input) -> IResult { + let scalar = map( + rule! { + #type_name + }, + ReturnBody::Scalar, + ); + let table = map( + rule! { + TABLE ~ "(" ~ #comma_separated_list0(udtf_arg) ~ ")" + }, + |(_, _, arg_types, _)| ReturnBody::Table(arg_types), + ); + + rule!( + #scalar: "" + | #table: "TABLE (, ...)" + )(i) + } + let lambda_udf = map( rule! { AS ~ "(" ~ #comma_separated_list0(ident) ~ ")" @@ -5049,16 +5074,23 @@ pub fn udf_definition(i: Input) -> IResult { }, ); - let udtf = map( + let scalar_udf_or_udtf = map( rule! { "(" ~ #comma_separated_list0(udtf_arg) ~ ")" - ~ RETURNS ~ TABLE ~ "(" ~ #comma_separated_list0(udtf_arg) ~ ")" + ~ RETURNS ~ ^#return_body ~ AS ~ ^#code_string }, - |(_, arg_types, _, _, _, _, return_types, _, _, sql)| UDFDefinition::UDTFSql { - arg_types, - return_types, - sql, + |(_, arg_types, _, _, return_body, _, sql)| match return_body { + ReturnBody::Scalar(return_type) => UDFDefinition::ScalarUDF { + arg_types, + definition: sql, + return_type, + }, + ReturnBody::Table(return_types) => UDFDefinition::UDTFSql { + arg_types, + return_types, + sql, + }, }, ); @@ -5126,7 +5158,7 @@ pub fn udf_definition(i: Input) -> IResult { #lambda_udf: "AS (, ...) -> " | #udaf: "(, ...) STATE {, ...} RETURNS LANGUAGE { ADDRESS= | AS } " | #udf: "(, ...) RETURNS LANGUAGE HANDLER= { ADDRESS= | AS } " - | #udtf: "(, ...) RETURNS TABLE (, ...) AS }" + | #scalar_udf_or_udtf: "(, ...) RETURNS AS }" )(i) } diff --git a/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs b/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs index 6df4c06fe110c..4a8350d4c7d97 100644 --- a/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs +++ b/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs @@ -32,19 +32,15 @@ use databend_common_catalog::table_function::TableFunction; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::display::scalar_ref_to_string; -use databend_common_expression::types::convert_to_type_name; use databend_common_expression::types::NumberScalar; use databend_common_expression::FunctionKind; use databend_common_expression::Scalar; use databend_common_functions::BUILTIN_FUNCTIONS; use databend_common_meta_app::principal::UDFDefinition; -use databend_common_meta_app::principal::UDTF; use databend_common_storages_basic::ResultCacheMetaManager; use databend_common_storages_basic::ResultScan; use databend_common_users::UserApiProvider; use derive_visitor::DriveMut; -use derive_visitor::VisitorMut; -use itertools::Itertools; use crate::binder::scalar::ScalarBinder; use crate::binder::table_args::bind_table_args; @@ -63,6 +59,7 @@ use crate::plans::ScalarItem; use crate::BindContext; use crate::Planner; use crate::ScalarExpr; +use crate::UDFArgVisitor; impl Binder { /// Bind a table function. @@ -76,43 +73,6 @@ impl Binder { alias: &Option, sample: &Option, ) -> Result<(SExpr, BindContext)> { - #[derive(VisitorMut)] - #[visitor(Expr(enter))] - struct UDTFArgVisitor<'a> { - udtf: &'a UDTF, - table_args: &'a TableArgs, - } - - impl UDTFArgVisitor<'_> { - fn enter_expr(&mut self, expr: &mut Expr) { - if let Expr::ColumnRef { span, column } = expr { - if column.database.is_some() || column.table.is_some() { - return; - } - assert_eq!(self.udtf.arg_types.len(), self.table_args.positioned.len()); - let Some((pos, (_, ty))) = self - .udtf - .arg_types - .iter() - .find_position(|(name, _)| name == column.column.name()) - else { - return; - }; - *expr = Expr::Cast { - span: *span, - expr: Box::new(Expr::Literal { - span: *span, - value: Literal::String(scalar_ref_to_string( - &self.table_args.positioned[pos].as_ref(), - )), - }), - target_type: convert_to_type_name(ty), - pg_style: false, - }; - } - } - } - let func_name = normalize_identifier(name, &self.name_resolution_ctx); if BUILTIN_FUNCTIONS @@ -190,17 +150,23 @@ impl Binder { .statement; if udtf.arg_types.len() != table_args.positioned.len() { - return Err(ErrorCode::UDFSchemaMismatch(format!( + return Err(ErrorCode::SyntaxException(format!( "UDTF '{}' argument types length {} does not match input arguments length {}", func_name, udtf.arg_types.len(), table_args.positioned.len() ))); } - let mut visitor = UDTFArgVisitor { - udtf: &udtf, - table_args: &table_args, - }; + + let args_expr = table_args + .positioned + .iter() + .map(|scalar| Expr::Literal { + span: None, + value: Literal::String(scalar_ref_to_string(&scalar.as_ref())), + }) + .collect::>(); + let mut visitor = UDFArgVisitor::new(&udtf.arg_types, &args_expr); stmt.drive_mut(&mut visitor); let binder = Binder::new( diff --git a/src/query/sql/src/planner/binder/udf.rs b/src/query/sql/src/planner/binder/udf.rs index 551242d5c3efd..8107a02842f02 100644 --- a/src/query/sql/src/planner/binder/udf.rs +++ b/src/query/sql/src/planner/binder/udf.rs @@ -27,6 +27,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::udf_client::UDFFlightClient; use databend_common_expression::DataField; use databend_common_meta_app::principal::LambdaUDF; +use databend_common_meta_app::principal::ScalarUDF; use databend_common_meta_app::principal::UDAFScript; use databend_common_meta_app::principal::UDFDefinition as PlanUDFDefinition; use databend_common_meta_app::principal::UDFScript; @@ -236,6 +237,32 @@ impl Binder { created_on: Utc::now(), }) } + UDFDefinition::ScalarUDF { + arg_types, + definition, + return_type, + } => { + let arg_types = arg_types + .iter() + .map(|(name, arg_type)| { + let column = normalize_identifier(name, &self.name_resolution_ctx).name; + let ty = DataType::from(&resolve_type_name_udf(arg_type)?); + Ok((column, ty)) + }) + .collect::>>()?; + let return_type = DataType::from(&resolve_type_name_udf(return_type)?); + + Ok(UserDefinedFunction { + name, + description, + definition: PlanUDFDefinition::ScalarUDF(ScalarUDF { + arg_types, + return_type, + definition: definition.clone(), + }), + created_on: Utc::now(), + }) + } } } diff --git a/src/query/sql/src/planner/semantic/mod.rs b/src/query/sql/src/planner/semantic/mod.rs index c0ef8d5f662e8..4164d8f66d802 100644 --- a/src/query/sql/src/planner/semantic/mod.rs +++ b/src/query/sql/src/planner/semantic/mod.rs @@ -48,6 +48,7 @@ pub use type_check::resolve_type_name_by_str; pub use type_check::resolve_type_name_udf; pub use type_check::validate_function_arg; pub use type_check::TypeChecker; +pub use udf_rewriter::UDFArgVisitor; pub(crate) use udf_rewriter::UdfRewriter; pub use view_rewriter::ViewRewriter; pub use window_check::WindowChecker; diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 7db66cddd60b3..000dbcbfa9a82 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -107,6 +107,7 @@ use databend_common_functions::RANK_WINDOW_FUNCTIONS; use databend_common_license::license::Feature; use databend_common_license::license_manager::LicenseManagerSwitch; use databend_common_meta_app::principal::LambdaUDF; +use databend_common_meta_app::principal::ScalarUDF; use databend_common_meta_app::principal::UDAFScript; use databend_common_meta_app::principal::UDFDefinition; use databend_common_meta_app::principal::UDFScript; @@ -191,6 +192,7 @@ use crate::ColumnEntry; use crate::DefaultExprBinder; use crate::IndexType; use crate::MetadataRef; +use crate::UDFArgVisitor; const DEFAULT_DECIMAL_PRECISION: i64 = 38; const DEFAULT_DECIMAL_SCALE: i64 = 0; @@ -4794,6 +4796,9 @@ impl<'a> TypeChecker<'a> { self.resolve_udaf_script(span, name, arguments, udf_def)?, )), UDFDefinition::UDTF(_) => unreachable!(), + UDFDefinition::ScalarUDF(udf_def) => Ok(Some( + self.resolve_scalar_udf(span, name, arguments, udf_def)?, + )), } } @@ -5188,6 +5193,57 @@ impl<'a> TypeChecker<'a> { ))) } + fn resolve_scalar_udf( + &mut self, + span: Span, + func_name: String, + arguments: &[Expr], + udf_definition: ScalarUDF, + ) -> Result> { + let arg_types = udf_definition.arg_types; + if arg_types.len() != arguments.len() { + return Err(ErrorCode::SyntaxException(format!( + "Require {} parameters, but got: {}", + arg_types.len(), + arguments.len() + )) + .set_span(span)); + } + let settings = self.ctx.get_settings(); + let sql_dialect = settings.get_sql_dialect()?; + let sql_tokens = tokenize_sql(udf_definition.definition.as_str())?; + let mut udf_expr = parse_expr(&sql_tokens, sql_dialect)?; + let mut visitor = UDFArgVisitor::new(&arg_types, arguments); + udf_expr.drive_mut(&mut visitor); + + // independent context + let box (expr, _) = TypeChecker::try_create( + &mut BindContext::new(), + self.ctx.clone(), + &NameResolutionContext::default(), + MetadataRef::default(), + &[], + self.forbid_udf, + )? + .resolve(&udf_expr)?; + let return_ty = udf_definition.return_type; + let expr = CastExpr { + span, + is_try: false, + argument: Box::new(expr), + target_type: Box::new(return_ty.clone()), + }; + Ok(Box::new(( + UDFLambdaCall { + span, + func_name, + scalar: Box::new(expr.into()), + } + .into(), + return_ty, + ))) + } + fn resolve_async_function( &mut self, span: Span, diff --git a/src/query/sql/src/planner/semantic/udf_rewriter.rs b/src/query/sql/src/planner/semantic/udf_rewriter.rs index 31ec866119f7b..969e0e83f6926 100644 --- a/src/query/sql/src/planner/semantic/udf_rewriter.rs +++ b/src/query/sql/src/planner/semantic/udf_rewriter.rs @@ -16,8 +16,13 @@ use std::collections::HashMap; use std::collections::VecDeque; use std::sync::Arc; +use databend_common_ast::ast::Expr; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::types::convert_to_type_name; +use databend_common_expression::types::DataType; +use derive_visitor::VisitorMut as StatementVisitorMut; +use itertools::Itertools; use crate::optimizer::ir::SExpr; use crate::plans::walk_expr_mut; @@ -238,3 +243,39 @@ impl<'a> VisitorMut<'a> for UdfRewriter { Ok(()) } } + +#[derive(StatementVisitorMut)] +#[visitor(Expr(enter))] +pub struct UDFArgVisitor<'a> { + arg_types: &'a [(String, DataType)], + args: &'a [Expr], +} + +impl<'a> UDFArgVisitor<'a> { + pub fn new(arg_types: &'a [(String, DataType)], args: &'a [Expr]) -> Self { + Self { arg_types, args } + } + + fn enter_expr(&mut self, expr: &mut Expr) { + if let Expr::ColumnRef { span, column } = expr { + if column.database.is_some() || column.table.is_some() { + return; + } + assert_eq!(self.arg_types.len(), self.args.len()); + let Some((pos, (_, ty))) = self + .arg_types + .iter() + .find_position(|(name, _)| name == column.column.name()) + else { + return; + }; + + *expr = Expr::Cast { + span: *span, + expr: Box::new(self.args[pos].clone()), + target_type: convert_to_type_name(ty), + pg_style: false, + } + } + } +} diff --git a/src/query/storages/system/src/user_functions_table.rs b/src/query/storages/system/src/user_functions_table.rs index 758df89bf70d7..374a84efb92ab 100644 --- a/src/query/storages/system/src/user_functions_table.rs +++ b/src/query/storages/system/src/user_functions_table.rs @@ -228,6 +228,18 @@ impl UserFunctionsTable { states: BTreeMap::new(), immutable: None, }, + UDFDefinition::ScalarUDF(x) => UserFunctionArguments { + arg_types: x + .arg_types + .iter() + .map(|(name, ty)| format!("{name} {ty}")) + .collect(), + return_type: Some(x.return_type.to_string()), + server: None, + parameters: vec![], + states: Default::default(), + immutable: None, + }, }, }) .collect()) diff --git a/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test b/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test index 7a2b2464e2442..748954be1898d 100644 --- a/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test +++ b/tests/sqllogictests/suites/base/03_common/03_0013_select_udf.test @@ -256,3 +256,63 @@ query T select * from filter_t2('hello'); ---- hello + +statement ok +CREATE OR REPLACE FUNCTION reverse_str(s STRING) RETURNS STRING AS $$ REVERSE(s) $$; + +query T +SELECT reverse_str('hello'); +---- +olleh + +query T +SELECT reverse_str('he' || 'llo'); +---- +olleh + +statement ok +CREATE OR REPLACE FUNCTION str_to_int(s STRING) RETURNS INT32 AS $$ s $$; + +statement error +SELECT str_to_int('hello'); + +query I +SELECT str_to_int('10'); +---- +10 + +statement ok +CREATE OR REPLACE FUNCTION str_len(s STRING) RETURNS INT AS $$ LENGTH(s) $$; + +query I +SELECT str_len('hello'); +---- +5 + +query I +SELECT str_len('he' || 'llo'); +---- +5 + +statement ok +CREATE OR REPLACE FUNCTION bool_not(b BOOLEAN) RETURNS BOOLEAN AS $$ NOT b $$; + +query B +SELECT bool_not(TRUE); +---- +0 + +query B +SELECT bool_not(FALSE); +---- +1 + +query B +SELECT bool_not(NULL); +---- +NULL + +query B +SELECT bool_not(1 = 0); +---- +1 From a0672c6578441662c143713c4497f51ff3a4c08c Mon Sep 17 00:00:00 2001 From: kould Date: Thu, 21 Aug 2025 17:25:02 +0800 Subject: [PATCH 2/2] test: add scalar udf parse test on `test_statement` --- src/query/ast/tests/it/parser.rs | 1 + src/query/ast/tests/it/testdata/stmt.txt | 40 ++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index f940912ad8a31..6b876d22218ba 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -830,6 +830,7 @@ SELECT * from s;"#, r#"attach table t 's3://a' connection=(access_key_id ='x' secret_access_key ='y' endpoint_url='http://127.0.0.1:9900')"#, r#"CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p));"#, r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace AS(p) -> not(is_null(p)) DESC = 'This is a description';"#, + r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace (p STRING) RETURNS BOOL AS $$ not(is_null(p)) $$;"#, r#"CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, r#"CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' HEADERS = ('X-Authorization' = '123') ADDRESS = 'http://0.0.0.0:8815';"#, r#"CREATE FUNCTION binary_reverse_table () RETURNS TABLE (c1 int) AS $$ select * from binary_reverse $$;"#, diff --git a/src/query/ast/tests/it/testdata/stmt.txt b/src/query/ast/tests/it/testdata/stmt.txt index cb30d79c9a424..c3f557f623e1c 100644 --- a/src/query/ast/tests/it/testdata/stmt.txt +++ b/src/query/ast/tests/it/testdata/stmt.txt @@ -25854,6 +25854,46 @@ CreateUDF( ) +---------- Input ---------- +CREATE OR REPLACE FUNCTION isnotempty_test_replace (p STRING) RETURNS BOOL AS $$ not(is_null(p)) $$; +---------- Output --------- +CREATE OR REPLACE FUNCTION isnotempty_test_replace (p STRING) RETURNS BOOLEAN AS $$ +not(is_null(p)) +$$ +---------- AST ------------ +CreateUDF( + CreateUDFStmt { + create_option: CreateOrReplace, + udf_name: Identifier { + span: Some( + 27..50, + ), + name: "isnotempty_test_replace", + quote: None, + ident_type: None, + }, + description: None, + definition: ScalarUDF { + arg_types: [ + ( + Identifier { + span: Some( + 52..53, + ), + name: "p", + quote: None, + ident_type: None, + }, + String, + ), + ], + definition: "not(is_null(p))", + return_type: Boolean, + }, + }, +) + + ---------- Input ---------- CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815'; ---------- Output ---------