Skip to content

Commit bd4f2da

Browse files
committed
feat: implement Scalar UDF
1 parent ff22a5c commit bd4f2da

File tree

16 files changed

+452
-51
lines changed

16 files changed

+452
-51
lines changed

src/meta/app/src/principal/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ pub use user_auth::AuthType;
109109
pub use user_auth::PasswordHashMethod;
110110
pub use user_defined_file_format::UserDefinedFileFormat;
111111
pub use user_defined_function::LambdaUDF;
112+
pub use user_defined_function::ScalarUDF;
112113
pub use user_defined_function::UDAFScript;
113114
pub use user_defined_function::UDFDefinition;
114115
pub use user_defined_function::UDFScript;

src/meta/app/src/principal/user_defined_function.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ pub struct UDTF {
6464
pub sql: String,
6565
}
6666

67+
/// User Defined Scalar Function (ScalarUDF)
68+
///
69+
/// # Fields
70+
/// - `arg_types`: arg name with data type
71+
/// - `return_type`: return data type
72+
/// - `definition`: typically including the code or expression implementing the function logic
73+
#[derive(Clone, Debug, Eq, PartialEq)]
74+
pub struct ScalarUDF {
75+
pub arg_types: Vec<(String, DataType)>,
76+
pub return_type: DataType,
77+
pub definition: String,
78+
}
79+
6780
#[derive(Clone, Debug, Eq, PartialEq)]
6881
pub struct UDAFScript {
6982
pub code: String,
@@ -85,6 +98,7 @@ pub enum UDFDefinition {
8598
UDFScript(UDFScript),
8699
UDAFScript(UDAFScript),
87100
UDTF(UDTF),
101+
ScalarUDF(ScalarUDF),
88102
}
89103

90104
impl UDFDefinition {
@@ -95,6 +109,7 @@ impl UDFDefinition {
95109
Self::UDFScript(_) => "UDFScript",
96110
Self::UDAFScript(_) => "UDAFScript",
97111
Self::UDTF(_) => "UDTF",
112+
UDFDefinition::ScalarUDF(_) => "ScalarUDF",
98113
}
99114
}
100115

@@ -104,6 +119,7 @@ impl UDFDefinition {
104119
Self::UDFServer(_) => false,
105120
Self::UDFScript(_) => false,
106121
Self::UDTF(_) => false,
122+
Self::ScalarUDF(_) => false,
107123
Self::UDAFScript(_) => true,
108124
}
109125
}
@@ -112,6 +128,7 @@ impl UDFDefinition {
112128
match self {
113129
Self::LambdaUDF(_) => "SQL",
114130
Self::UDTF(_) => "SQL",
131+
Self::ScalarUDF(_) => "SQL",
115132
Self::UDFServer(x) => x.language.as_str(),
116133
Self::UDFScript(x) => x.language.as_str(),
117134
Self::UDAFScript(x) => x.language.as_str(),
@@ -329,6 +346,19 @@ impl Display for UDFDefinition {
329346
}
330347
write!(f, ") AS $${sql}$$")?;
331348
}
349+
UDFDefinition::ScalarUDF(ScalarUDF {
350+
arg_types,
351+
return_type,
352+
definition,
353+
}) => {
354+
for (i, (name, ty)) in arg_types.iter().enumerate() {
355+
if i > 0 {
356+
write!(f, ", ")?;
357+
}
358+
write!(f, "{name} {ty}")?;
359+
}
360+
write!(f, ") RETURNS {return_type} AS $${definition}$$")?;
361+
}
332362
}
333363
Ok(())
334364
}

src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,75 @@ impl FromToProto for mt::UDTF {
357357
}
358358
}
359359

360+
impl FromToProto for mt::ScalarUDF {
361+
type PB = pb::ScalarUdf;
362+
363+
fn get_pb_ver(p: &Self::PB) -> u64 {
364+
p.ver
365+
}
366+
367+
fn from_pb(p: Self::PB) -> Result<Self, Incompatible>
368+
where Self: Sized {
369+
reader_check_msg(p.ver, p.min_reader_ver)?;
370+
371+
let mut arg_types = Vec::new();
372+
for arg_ty in p.arg_types {
373+
let ty_pb = arg_ty.ty.ok_or_else(|| {
374+
Incompatible::new("ScalarUDF.arg_types.ty can not be None".to_string())
375+
})?;
376+
let ty = TableDataType::from_pb(ty_pb)?;
377+
378+
arg_types.push((arg_ty.name, (&ty).into()));
379+
}
380+
381+
let return_type_pb = p.return_type.ok_or_else(|| {
382+
Incompatible::new("ScalarUDF.return_type can not be None".to_string())
383+
})?;
384+
let return_type = TableDataType::from_pb(return_type_pb)?;
385+
386+
Ok(Self {
387+
arg_types,
388+
return_type: (&return_type).into(),
389+
definition: p.definition,
390+
})
391+
}
392+
393+
fn to_pb(&self) -> Result<Self::PB, Incompatible> {
394+
let mut arg_types = Vec::with_capacity(self.arg_types.len());
395+
for (arg_name, arg_type) in self.arg_types.iter() {
396+
let arg_type = infer_schema_type(arg_type)
397+
.map_err(|e| {
398+
Incompatible::new(format!(
399+
"Convert DataType to TableDataType failed: {}",
400+
e.message()
401+
))
402+
})?
403+
.to_pb()?;
404+
arg_types.push(UdtfArg {
405+
name: arg_name.clone(),
406+
ty: Some(arg_type),
407+
});
408+
}
409+
410+
let return_type = infer_schema_type(&self.return_type)
411+
.map_err(|e| {
412+
Incompatible::new(format!(
413+
"Convert DataType to TableDataType failed: {}",
414+
e.message()
415+
))
416+
})?
417+
.to_pb()?;
418+
419+
Ok(pb::ScalarUdf {
420+
ver: VER,
421+
min_reader_ver: MIN_READER_VER,
422+
arg_types,
423+
return_type: Some(return_type),
424+
definition: self.definition.clone(),
425+
})
426+
}
427+
}
428+
360429
impl FromToProto for mt::UserDefinedFunction {
361430
type PB = pb::UserDefinedFunction;
362431
fn get_pb_ver(p: &Self::PB) -> u64 {
@@ -380,6 +449,9 @@ impl FromToProto for mt::UserDefinedFunction {
380449
Some(pb::user_defined_function::Definition::Udtf(udtf)) => {
381450
mt::UDFDefinition::UDTF(mt::UDTF::from_pb(udtf)?)
382451
}
452+
Some(pb::user_defined_function::Definition::ScalarUdf(scalar_udf)) => {
453+
mt::UDFDefinition::ScalarUDF(mt::ScalarUDF::from_pb(scalar_udf)?)
454+
}
383455
None => {
384456
return Err(Incompatible::new(
385457
"UserDefinedFunction.definition cannot be None".to_string(),
@@ -415,6 +487,9 @@ impl FromToProto for mt::UserDefinedFunction {
415487
mt::UDFDefinition::UDTF(udtf) => {
416488
pb::user_defined_function::Definition::Udtf(udtf.to_pb()?)
417489
}
490+
mt::UDFDefinition::ScalarUDF(scalar_udf) => {
491+
pb::user_defined_function::Definition::ScalarUdf(scalar_udf.to_pb()?)
492+
}
418493
};
419494

420495
Ok(pb::UserDefinedFunction {

src/meta/proto-conv/src/util.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[
173173
(141, "2025-08-06: Add: row_access.proto"),
174174
(142, "2025-08-15: Add: table_meta add row_access_policy"),
175175
(143, "2025-08-18: Add: add UDTF"),
176+
(144, "2025-08-18: Add: add ScalarUDF"),
176177
// Dear developer:
177178
// If you're gonna add a new metadata version, you'll have to add a test for it.
178179
// You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`)

src/meta/proto-conv/tests/it/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,4 @@ mod v140_task_message;
135135
mod v141_row_access_policy;
136136
mod v142_table_row_access_policy;
137137
mod v143_udtf;
138+
mod v144_scalar_udf;
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use chrono::DateTime;
16+
use chrono::Utc;
17+
use databend_common_expression::types::DataType;
18+
use databend_common_meta_app::principal::ScalarUDF;
19+
use databend_common_meta_app::principal::UDFDefinition;
20+
use databend_common_meta_app::principal::UserDefinedFunction;
21+
use fastrace::func_name;
22+
23+
use crate::common;
24+
25+
// These bytes are built when a new version is introduced,
26+
// and are kept for backward compatibility test.
27+
//
28+
// *************************************************************
29+
// * These messages should never be updated, *
30+
// * only be added when a new version is added, *
31+
// * or be removed when an old version is no longer supported. *
32+
// *************************************************************
33+
//
34+
// The message bytes are built from the output of `test_pb_from_to()`
35+
#[test]
fn test_decode_v144_scalar_udf() -> anyhow::Result<()> {
    // Frozen v144 encoding of the expected value below; never edit these bytes.
    let encoded: &[u8] = &[
        10, 15, 116, 101, 115, 116, 95, 115, 99, 97, 108, 97, 114, 95, 117, 100, 102, 18, 21, 84,
        104, 105, 115, 32, 105, 115, 32, 97, 32, 100, 101, 115, 99, 114, 105, 112, 116, 105, 111,
        110, 74, 69, 10, 16, 10, 2, 99, 49, 18, 10, 146, 2, 0, 160, 6, 144, 1, 168, 6, 24, 10, 16,
        10, 2, 99, 50, 18, 10, 138, 2, 0, 160, 6, 144, 1, 168, 6, 24, 18, 10, 170, 2, 0, 160, 6,
        144, 1, 168, 6, 24, 26, 12, 67, 85, 82, 82, 69, 78, 84, 95, 68, 65, 84, 69, 160, 6, 144, 1,
        168, 6, 24, 42, 23, 50, 48, 50, 51, 45, 49, 50, 45, 49, 53, 32, 48, 49, 58, 50, 54, 58, 48,
        57, 32, 85, 84, 67, 160, 6, 144, 1, 168, 6, 24,
    ];

    // Built through a closure because each check below consumes its own copy.
    let expected = || {
        let scalar_udf = ScalarUDF {
            arg_types: vec![(s("c1"), DataType::String), (s("c2"), DataType::Boolean)],
            return_type: DataType::Date,
            definition: "CURRENT_DATE".to_string(),
        };

        UserDefinedFunction {
            name: "test_scalar_udf".to_string(),
            description: "This is a description".to_string(),
            definition: UDFDefinition::ScalarUDF(scalar_udf),
            created_on: DateTime::<Utc>::from_timestamp(1702603569, 0).unwrap(),
        }
    };

    // Round-trip with the current code, then decode the frozen v144 bytes.
    common::test_pb_from_to(func_name!(), expected())?;
    common::test_load_old(func_name!(), encoded, 144, expected())
}
61+
62+
/// Shorthand for turning any `ToString` value into an owned `String`.
fn s(ss: impl ToString) -> String {
    ToString::to_string(&ss)
}

src/meta/protos/proto/udf.proto

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,19 @@ message UDTF {
8585
string sql = 3;
8686
}
8787

88+
// A scalar user-defined function: named and typed arguments, a single
// return type, and the text of the function body.
message ScalarUDF {
89+
uint64 ver = 100;
90+
uint64 min_reader_ver = 101;
91+
92+
// Argument names paired with their data types
93+
repeated UDTFArg arg_types = 1;
94+
// Data type of the returned value
95+
DataType return_type = 2;
96+
// Typically the code or expression implementing the function logic
97+
string definition = 3;
98+
}
99+
100+
88101
message UserDefinedFunction {
89102
uint64 ver = 100;
90103
uint64 min_reader_ver = 101;
@@ -97,6 +110,7 @@ message UserDefinedFunction {
97110
UDFScript udf_script = 6;
98111
UDAFScript udaf_script = 7;
99112
UDTF udtf = 8;
113+
ScalarUDF scalar_udf = 9;
100114
}
101115
// The time udf created.
102116
optional string created_on = 5;

src/query/ast/src/ast/statements/udf.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ pub enum UDFDefinition {
7676
return_types: Vec<(Identifier, TypeName)>,
7777
sql: String,
7878
},
79+
ScalarUDF {
80+
arg_types: Vec<(Identifier, TypeName)>,
81+
definition: String,
82+
return_type: TypeName,
83+
},
7984
}
8085

8186
impl Display for UDFDefinition {
@@ -198,6 +203,18 @@ impl Display for UDFDefinition {
198203
)?;
199204
write!(f, ") AS $$\n{sql}\n$$")?;
200205
}
206+
UDFDefinition::ScalarUDF {
207+
arg_types,
208+
definition,
209+
return_type,
210+
} => {
211+
write!(f, "(")?;
212+
write_comma_separated_list(
213+
f,
214+
arg_types.iter().map(|(name, ty)| format!("{name} {ty}")),
215+
)?;
216+
write!(f, ") RETURNS {return_type} AS $$\n{definition}\n$$")?;
217+
}
201218
UDFDefinition::UDAFScript {
202219
arg_types,
203220
state_fields: state_types,

src/query/ast/src/parser/statement.rs

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4975,6 +4975,31 @@ pub fn udf_script_or_address(i: Input) -> IResult<(String, bool)> {
49754975
}
49764976

49774977
pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
4978+
/// Parsed form of what follows `RETURNS` in a SQL-bodied UDF definition.
///
/// The caller uses the variant to pick between `UDFDefinition::ScalarUDF`
/// and `UDFDefinition::UDTFSql`.
enum ReturnBody {
4979+
/// A single return type (scalar function).
Scalar(TypeName),
4980+
/// The `TABLE (...)` list of named, typed return columns (table function).
Table(Vec<(Identifier, TypeName)>),
4981+
}
4982+
4983+
/// Parses the return clause body: either a bare type name or
/// `TABLE ( <udtf_arg>, ... )`.
///
/// Returns `ReturnBody::Scalar` for the first form and `ReturnBody::Table`
/// for the second.
fn return_body(i: Input) -> IResult<ReturnBody> {
4984+
// Scalar form: a type name wrapped directly in `ReturnBody::Scalar`.
let scalar = map(
4985+
rule! {
4986+
#type_name
4987+
},
4988+
ReturnBody::Scalar,
4989+
);
4990+
// Table form: only the parsed argument list is kept; the `TABLE` keyword
// and both parentheses are discarded.
let table = map(
4991+
rule! {
4992+
TABLE ~ "(" ~ #comma_separated_list0(udtf_arg) ~ ")"
4993+
},
4994+
|(_, _, arg_types, _)| ReturnBody::Table(arg_types),
4995+
);
4996+
4997+
// The scalar alternative is listed before the table alternative.
// NOTE(review): the quoted strings are presumably labels shown in parser
// error hints — confirm against the `rule!` macro documentation.
rule!(
4998+
#scalar: "<return_type>"
4999+
| #table: "TABLE (<return_type>, ...)"
5000+
)(i)
5001+
}
5002+
49785003
let lambda_udf = map(
49795004
rule! {
49805005
AS ~ "(" ~ #comma_separated_list0(ident) ~ ")"
@@ -5049,16 +5074,23 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
50495074
},
50505075
);
50515076

5052-
let udtf = map(
5077+
let scalar_udf_or_udtf = map(
50535078
rule! {
50545079
"(" ~ #comma_separated_list0(udtf_arg) ~ ")"
5055-
~ RETURNS ~ TABLE ~ "(" ~ #comma_separated_list0(udtf_arg) ~ ")"
5080+
~ RETURNS ~ ^#return_body
50565081
~ AS ~ ^#code_string
50575082
},
5058-
|(_, arg_types, _, _, _, _, return_types, _, _, sql)| UDFDefinition::UDTFSql {
5059-
arg_types,
5060-
return_types,
5061-
sql,
5083+
|(_, arg_types, _, _, return_body, _, sql)| match return_body {
5084+
ReturnBody::Scalar(return_type) => UDFDefinition::ScalarUDF {
5085+
arg_types,
5086+
definition: sql,
5087+
return_type,
5088+
},
5089+
ReturnBody::Table(return_types) => UDFDefinition::UDTFSql {
5090+
arg_types,
5091+
return_types,
5092+
sql,
5093+
},
50625094
},
50635095
);
50645096

@@ -5126,7 +5158,7 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
51265158
#lambda_udf: "AS (<parameter>, ...) -> <definition expr>"
51275159
| #udaf: "(<arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> } "
51285160
| #udf: "(<arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER=<handler> { ADDRESS=<udf_server_address> | AS <language_codes> } "
5129-
| #udtf: "(<arg_type>, ...) RETURNS TABLE (<return_type>, ...) AS <sql> }"
5161+
| #scalar_udf_or_udtf: "(<arg_type>, ...) RETURNS <return body> AS <sql> }"
51305162
)(i)
51315163
}
51325164

0 commit comments

Comments
 (0)