Skip to content

Commit 61d285d

Browse files
committed
feat: Implement UDTF
1 parent 473202d commit 61d285d

File tree

11 files changed

+255
-6
lines changed

11 files changed

+255
-6
lines changed

src/meta/app/src/principal/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ pub use task::Status;
9696
pub use task::Task;
9797
pub use task::TaskMessage;
9898
pub use task::TaskRun;
99+
pub use user_defined_function::UDTF;
99100
pub use task::WarehouseOptions;
100101
pub use task_ident::TaskIdent;
101102
pub use task_ident::TaskIdentRaw;

src/meta/app/src/principal/user_defined_function.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ pub struct UDFScript {
5151
pub immutable: Option<bool>,
5252
}
5353

54+
#[derive(Clone, Debug, Eq, PartialEq)]
55+
pub struct UDTF {
56+
pub arg_types: BTreeMap<String, DataType>,
57+
pub return_types: BTreeMap<String, DataType>,
58+
pub sql: String,
59+
}
60+
5461
#[derive(Clone, Debug, Eq, PartialEq)]
5562
pub struct UDAFScript {
5663
pub code: String,
@@ -71,6 +78,7 @@ pub enum UDFDefinition {
7178
UDFServer(UDFServer),
7279
UDFScript(UDFScript),
7380
UDAFScript(UDAFScript),
81+
UDTF(UDTF),
7482
}
7583

7684
impl UDFDefinition {
@@ -80,6 +88,7 @@ impl UDFDefinition {
8088
Self::UDFServer(_) => "UDFServer",
8189
Self::UDFScript(_) => "UDFScript",
8290
Self::UDAFScript(_) => "UDAFScript",
91+
Self::UDTF(_) => "UDTF",
8392
}
8493
}
8594

@@ -88,13 +97,15 @@ impl UDFDefinition {
8897
Self::LambdaUDF(_) => false,
8998
Self::UDFServer(_) => false,
9099
Self::UDFScript(_) => false,
100+
Self::UDTF(_) => false,
91101
Self::UDAFScript(_) => true,
92102
}
93103
}
94104

95105
pub fn language(&self) -> &str {
96106
match self {
97107
Self::LambdaUDF(_) => "SQL",
108+
Self::UDTF(_) => "SQL",
98109
Self::UDFServer(x) => x.language.as_str(),
99110
Self::UDFScript(x) => x.language.as_str(),
100111
Self::UDAFScript(x) => x.language.as_str(),
@@ -292,6 +303,26 @@ impl Display for UDFDefinition {
292303
}
293304
write!(f, " }} RETURNS {return_type} LANGUAGE {language} IMPORTS = {imports:?} PACKAGES = {packages:?} RUNTIME_VERSION = {runtime_version} AS $${code}$$")?;
294305
}
306+
UDFDefinition::UDTF(UDTF {
307+
arg_types,
308+
return_types,
309+
sql,
310+
}) => {
311+
for (i, (name, ty)) in arg_types.iter().enumerate() {
312+
if i > 0 {
313+
write!(f, ", ")?;
314+
}
315+
write!(f, "{name} {ty}")?;
316+
}
317+
write!(f, ") RETURNS (")?;
318+
for (i, (name, ty)) in return_types.iter().enumerate() {
319+
if i > 0 {
320+
write!(f, ", ")?;
321+
}
322+
write!(f, "{name} {ty}")?;
323+
}
324+
write!(f, ") AS $${sql}$$")?;
325+
}
295326
}
296327
Ok(())
297328
}

src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::BTreeMap;
1516
use chrono::DateTime;
1617
use chrono::Utc;
1718
use databend_common_expression::infer_schema_type;
@@ -275,6 +276,87 @@ impl FromToProto for mt::UDAFScript {
275276
}
276277
}
277278

279+
impl FromToProto for mt::UDTF {
280+
type PB = pb::Udtf;
281+
282+
fn get_pb_ver(p: &Self::PB) -> u64 {
283+
p.ver
284+
}
285+
286+
fn from_pb(p: Self::PB) -> Result<Self, Incompatible>
287+
where
288+
Self: Sized
289+
{
290+
reader_check_msg(p.ver, p.min_reader_ver)?;
291+
292+
let mut arg_types = BTreeMap::new();
293+
for (arg_name, arg_ty) in p
294+
.arg_names
295+
.into_iter().zip(p.arg_types.into_iter()) {
296+
let ty = (&TableDataType::from_pb(arg_ty)?).into();
297+
298+
arg_types.insert(arg_name, ty);
299+
}
300+
301+
let mut return_types = BTreeMap::new();
302+
for (return_name, arg_ty) in p
303+
.return_names
304+
.into_iter().zip(p.return_types.into_iter()) {
305+
let ty = (&TableDataType::from_pb(arg_ty)?).into();
306+
307+
return_types.insert(return_name, ty);
308+
}
309+
310+
Ok(Self {
311+
arg_types,
312+
return_types,
313+
sql: p.sql,
314+
})
315+
}
316+
317+
fn to_pb(&self) -> Result<Self::PB, Incompatible> {
318+
let mut arg_names = Vec::with_capacity(self.arg_types.len());
319+
let mut arg_types = Vec::with_capacity(self.arg_types.len());
320+
for (arg_name, arg_type) in self.arg_types.iter() {
321+
let arg_type = infer_schema_type(arg_type)
322+
.map_err(|e| {
323+
Incompatible::new(format!(
324+
"Convert DataType to TableDataType failed: {}",
325+
e.message()
326+
))
327+
})?
328+
.to_pb()?;
329+
arg_names.push(arg_name.clone());
330+
arg_types.push(arg_type);
331+
}
332+
333+
let mut return_names = Vec::with_capacity(self.return_types.len());
334+
let mut return_types = Vec::with_capacity(self.return_types.len());
335+
for (return_name, return_type) in self.return_types.iter() {
336+
let return_type = infer_schema_type(return_type)
337+
.map_err(|e| {
338+
Incompatible::new(format!(
339+
"Convert DataType to TableDataType failed: {}",
340+
e.message()
341+
))
342+
})?
343+
.to_pb()?;
344+
return_names.push(return_name.clone());
345+
return_types.push(return_type);
346+
}
347+
348+
Ok(pb::Udtf {
349+
ver: VER,
350+
min_reader_ver: MIN_READER_VER,
351+
arg_names,
352+
arg_types,
353+
return_names,
354+
return_types,
355+
sql: self.sql.clone(),
356+
})
357+
}
358+
}
359+
278360
impl FromToProto for mt::UserDefinedFunction {
279361
type PB = pb::UserDefinedFunction;
280362
fn get_pb_ver(p: &Self::PB) -> u64 {
@@ -295,6 +377,9 @@ impl FromToProto for mt::UserDefinedFunction {
295377
Some(pb::user_defined_function::Definition::UdafScript(udaf_script)) => {
296378
mt::UDFDefinition::UDAFScript(mt::UDAFScript::from_pb(udaf_script)?)
297379
}
380+
Some(pb::user_defined_function::Definition::Udtf(udtf)) => {
381+
mt::UDFDefinition::UDTF(mt::UDTF::from_pb(udtf)?)
382+
}
298383
None => {
299384
return Err(Incompatible::new(
300385
"UserDefinedFunction.definition cannot be None".to_string(),
@@ -327,6 +412,9 @@ impl FromToProto for mt::UserDefinedFunction {
327412
mt::UDFDefinition::UDAFScript(udaf_script) => {
328413
pb::user_defined_function::Definition::UdafScript(udaf_script.to_pb()?)
329414
}
415+
mt::UDFDefinition::UDTF(udtf) => {
416+
pb::user_defined_function::Definition::Udtf(udtf.to_pb()?)
417+
}
330418
};
331419

332420
Ok(pb::UserDefinedFunction {

src/meta/protos/proto/udf.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@ message UDAFScript {
6969
repeated string packages = 8;
7070
}
7171

72+
message UDTF {
73+
uint64 ver = 100;
74+
uint64 min_reader_ver = 101;
75+
76+
repeated string arg_names = 1;
77+
repeated DataType arg_types = 2;
78+
repeated string return_names = 3;
79+
repeated DataType return_types = 4;
80+
string sql = 5;
81+
}
82+
7283
message UserDefinedFunction {
7384
uint64 ver = 100;
7485
uint64 min_reader_ver = 101;
@@ -80,6 +91,7 @@ message UserDefinedFunction {
8091
UDFServer udf_server = 4;
8192
UDFScript udf_script = 6;
8293
UDAFScript udaf_script = 7;
94+
UDTF udtf = 8;
8395
}
8496
// The time udf created.
8597
optional string created_on = 5;

src/query/ast/src/ast/common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::ast::WithOptions;
2525
use crate::Span;
2626

2727
// Identifier of table name or column name.
28-
#[derive(Debug, Clone, PartialEq, Eq, Drive, DriveMut)]
28+
#[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Drive, DriveMut)]
2929
pub struct Identifier {
3030
pub span: Span,
3131
pub name: String,
@@ -34,7 +34,7 @@ pub struct Identifier {
3434
pub ident_type: IdentifierType,
3535
}
3636

37-
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
37+
#[derive(Debug, Copy, Clone, PartialEq, Eq, Ord, PartialOrd, Default)]
3838
pub enum IdentifierType {
3939
#[default]
4040
None,

src/query/ast/src/ast/statements/udf.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ pub enum UDFDefinition {
7171
language: String,
7272
runtime_version: String,
7373
},
74+
UDTFSql {
75+
arg_types: BTreeMap<Identifier, TypeName>,
76+
return_types: BTreeMap<Identifier, TypeName>,
77+
sql: String,
78+
}
7479
}
7580

7681
impl Display for UDFDefinition {
@@ -176,6 +181,17 @@ impl Display for UDFDefinition {
176181
}
177182
write!(f, " ADDRESS = '{address}'")?;
178183
}
184+
UDFDefinition::UDTFSql {
185+
arg_types,
186+
return_types,
187+
sql
188+
} => {
189+
write!(f, "(")?;
190+
write_comma_separated_list(f, arg_types.iter().map(|(name, ty)| format!("{name} {ty}")))?;
191+
write!(f, ") RETURNS TABLE (")?;
192+
write_comma_separated_list(f, return_types.iter().map(|(name, ty)| format!("{name} {ty}")))?;
193+
write!(f, ") AS $$\n{sql}\n$$")?;
194+
}
179195
UDFDefinition::UDAFScript {
180196
arg_types,
181197
state_fields: state_types,

src/query/ast/src/parser/statement.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4969,6 +4969,21 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
49694969
},
49704970
);
49714971

4972+
let udtf = map(
4973+
rule! {
4974+
"(" ~ #comma_separated_list0(udtf_arg) ~ ")"
4975+
~ RETURNS ~ TABLE ~ "(" ~ #comma_separated_list0(udtf_arg) ~ ")"
4976+
~ AS ~ ^#code_string
4977+
},
4978+
|(_, arg_types, _, _, _, _, return_types, _, _, sql)| {
4979+
UDFDefinition::UDTFSql {
4980+
arg_types: BTreeMap::from_iter(arg_types),
4981+
return_types: BTreeMap::from_iter(return_types),
4982+
sql,
4983+
}
4984+
}
4985+
);
4986+
49724987
let udaf = map(
49734988
rule! {
49744989
"(" ~ #comma_separated_list0(type_name) ~ ")"
@@ -5033,7 +5048,14 @@ pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
50335048
#lambda_udf: "AS (<parameter>, ...) -> <definition expr>"
50345049
| #udaf: "(<arg_type>, ...) STATE {<state_field>, ...} RETURNS <return_type> LANGUAGE <language> { ADDRESS=<udf_server_address> | AS <language_codes> } "
50355050
| #udf: "(<arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER=<handler> { ADDRESS=<udf_server_address> | AS <language_codes> } "
5051+
| #udtf: "(<arg_type>, ...) RETURNS TABLE (<return_type>, ...) AS <sql> }"
5052+
)(i)
5053+
}
50365054

5055+
fn udtf_arg(i:Input) -> IResult<(Identifier, TypeName)> {
5056+
map(
5057+
rule! { #ident ~ ^#type_name },
5058+
|(name, ty)| (name, ty),
50375059
)(i)
50385060
}
50395061

src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use databend_common_ast::ast::SelectTarget;
2525
use databend_common_ast::ast::TableAlias;
2626
use databend_common_ast::ast::TableReference;
2727
use databend_common_ast::Span;
28+
use databend_common_catalog::catalog::CatalogManager;
2829
use databend_common_catalog::catalog_kind::CATALOG_DEFAULT;
2930
use databend_common_catalog::table_args::TableArgs;
3031
use databend_common_catalog::table_function::TableFunction;
@@ -34,6 +35,7 @@ use databend_common_expression::types::NumberScalar;
3435
use databend_common_expression::FunctionKind;
3536
use databend_common_expression::Scalar;
3637
use databend_common_functions::BUILTIN_FUNCTIONS;
38+
use databend_common_meta_app::principal::UDFDefinition;
3739
use databend_common_storages_result_cache::ResultCacheMetaManager;
3840
use databend_common_storages_result_cache::ResultScan;
3941
use databend_common_users::UserApiProvider;
@@ -45,11 +47,11 @@ use crate::binder::ColumnBindingBuilder;
4547
use crate::binder::Visibility;
4648
use crate::optimizer::ir::SExpr;
4749
use crate::planner::semantic::normalize_identifier;
48-
use crate::plans::EvalScalar;
50+
use crate::plans::{EvalScalar, Plan};
4951
use crate::plans::FunctionCall;
5052
use crate::plans::RelOperator;
5153
use crate::plans::ScalarItem;
52-
use crate::BindContext;
54+
use crate::{BindContext, Planner};
5355
use crate::ScalarExpr;
5456

5557
impl Binder {
@@ -129,6 +131,42 @@ impl Binder {
129131
);
130132
let table_args = bind_table_args(&mut scalar_binder, params, named_params)?;
131133

134+
let tenant = self.ctx.get_tenant();
135+
let udtf_result = databend_common_base::runtime::block_on(async {
136+
if let Some(UDFDefinition::UDTF(udtf)) = UserApiProvider::instance().get_udf(&tenant, &func_name.name).await?.map(|udf| udf.definition) {
137+
let mut sql = udtf.sql;
138+
139+
for (name, arg) in table_args.named.iter() {
140+
// FIXME: Parameter substitution
141+
let Some(ty) = udtf.arg_types.get(name) else { return Err(ErrorCode::InvalidArgument(format!("Function '{func_name}' does not have a parameter named '{name}'"))) };
142+
143+
sql = sql.replace(name, &arg.to_string());
144+
}
145+
let mut planner = Planner::new(self.ctx.clone());
146+
let (_, extras) = planner.plan_sql(&sql).await?;
147+
let binder = Binder::new(
148+
self.ctx.clone(),
149+
CatalogManager::instance(),
150+
self.name_resolution_ctx.clone(),
151+
self.metadata.clone(),
152+
)
153+
.with_subquery_executor(self.subquery_executor.clone());
154+
let plan = binder.bind(&extras.statement).await?;
155+
156+
let Plan::Query {
157+
s_expr, bind_context, ..
158+
} = plan else {
159+
return Err(ErrorCode::UDFRuntimeError("Query in UDTF returned no result set"))
160+
};
161+
162+
return Ok(Some((*s_expr, *bind_context)))
163+
}
164+
Ok(None)
165+
});
166+
if let Some(result) = udtf_result? {
167+
return Ok(result);
168+
}
169+
132170
if func_name.name.eq_ignore_ascii_case("result_scan") {
133171
self.bind_result_scan(bind_context, span, alias, &table_args)
134172
} else {

0 commit comments

Comments
 (0)