Skip to content

Commit 97e1f66

Browse files
committed
fix: nitpicks
1 parent 04e1139 commit 97e1f66

File tree

2 files changed

+32
-42
lines changed

2 files changed

+32
-42
lines changed

host/src/udf_query.rs

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! Embedded SQL approach for executing Python UDFs within SQL queries.
1+
//! Embedded SQL approach for executing UDFs within SQL queries.
22
33
use std::collections::HashMap;
44

@@ -38,29 +38,33 @@ impl std::fmt::Debug for UdfQueryParser<'_> {
3838

3939
impl<'a> UdfQueryParser<'a> {
4040
/// Registers the UDF query in DataFusion.
41-
pub async fn new(
42-
components: HashMap<String, &'a WasmComponentPrecompiled>,
43-
) -> DataFusionResult<Self> {
44-
Ok(Self { components })
41+
pub fn new(components: HashMap<String, &'a WasmComponentPrecompiled>) -> Self {
42+
Self { components }
4543
}
4644

47-
/// Parses a SQL query that defines & uses Python UDFs into a [ParsedQuery].
45+
/// Parses a SQL query that defines & uses UDFs into a [ParsedQuery].
4846
pub async fn parse(
4947
&self,
5048
udf_query: &str,
5149
permissions: &WasmPermissions,
5250
task_ctx: &TaskContext,
5351
) -> DataFusionResult<ParsedQuery> {
54-
let (code, sql, lang) = self.parse_inner(udf_query, task_ctx)?;
55-
56-
let component = self.components.get(&lang).ok_or_else(|| {
57-
DataFusionError::Plan(format!(
58-
"no WASM component registered for language: {:?}",
59-
lang
60-
))
61-
})?;
52+
let (code, sql) = self.parse_inner(udf_query, task_ctx)?;
53+
54+
let mut udfs = vec![];
55+
for (lang, blocks) in code {
56+
let component = self.components.get(&lang).ok_or_else(|| {
57+
DataFusionError::Plan(format!(
58+
"no WASM component registered for language: {:?}",
59+
lang
60+
))
61+
})?;
62+
63+
for code in blocks {
64+
udfs.extend(WasmScalarUdf::new(component, permissions, code).await?);
65+
}
66+
}
6267

63-
let udfs = WasmScalarUdf::new(component, permissions, code).await?;
6468
Ok(ParsedQuery { udfs, sql })
6569
}
6670

@@ -70,7 +74,7 @@ impl<'a> UdfQueryParser<'a> {
7074
&self,
7175
query: &str,
7276
task_ctx: &TaskContext,
73-
) -> DataFusionResult<(String, String, String)> {
77+
) -> DataFusionResult<(HashMap<String, Vec<String>>, String)> {
7478
let options = task_ctx.session_config().options();
7579

7680
let dialect = dialect_from_str(options.sql_parser.dialect.clone()).expect("valid dialect");
@@ -82,20 +86,20 @@ impl<'a> UdfQueryParser<'a> {
8286
.build()?
8387
.parse_statements()?;
8488

85-
let mut udf_code = String::new();
8689
let mut sql = String::new();
87-
88-
let mut udf_language = String::new();
90+
let mut udf_blocks: HashMap<String, Vec<String>> = HashMap::new();
8991
for s in statements {
9092
let Statement::Statement(stmt) = s else {
9193
continue;
9294
};
9395

9496
match parse_udf(*stmt)? {
9597
Parsed::Udf { code, language } => {
96-
udf_language = language;
97-
udf_code.push_str(&code);
98-
udf_code.push('\n');
98+
if let Some(existing) = udf_blocks.get_mut(&language) {
99+
existing.push(code);
100+
} else {
101+
udf_blocks.insert(language.clone(), vec![code]);
102+
}
99103
}
100104
Parsed::Other(statement) => {
101105
sql.push_str(&statement);
@@ -104,17 +108,11 @@ impl<'a> UdfQueryParser<'a> {
104108
}
105109
}
106110

107-
if udf_code.is_empty() {
108-
return Err(DataFusionError::Plan(
109-
"UDF not defined in query".to_string(),
110-
));
111-
}
112-
113111
if sql.is_empty() {
114112
return Err(DataFusionError::Plan("no SQL query found".to_string()));
115113
}
116114

117-
Ok((udf_code, sql, udf_language))
115+
Ok((udf_blocks, sql))
118116
}
119117
}
120118

@@ -168,7 +166,7 @@ fn extract_function_body(body: &CreateFunctionBody) -> DataFusionResult<&str> {
168166
expression_into_str(e)
169167
}
170168
CreateFunctionBody::Return(_) => Err(DataFusionError::Plan(
171-
"`RETURN` function body not supported for Python UDFs".to_string(),
169+
"`RETURN` function body not supported for UDFs".to_string(),
172170
)),
173171
}
174172
}

host/tests/integration_tests/udf_query.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ SELECT add_one(1);
4545
let ctx = SessionContext::new();
4646
let component = python_component().await;
4747

48-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]))
49-
.await
50-
.unwrap();
48+
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
5149
let parsed_query = parser
5250
.parse(query, &WasmPermissions::new(), ctx.task_ctx().as_ref())
5351
.await
@@ -91,9 +89,7 @@ SELECT add_one(1), multiply_two(3);
9189
let ctx = SessionContext::new();
9290
let component = python_component().await;
9391

94-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]))
95-
.await
96-
.unwrap();
92+
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
9793
let parsed_query = parser
9894
.parse(query, &WasmPermissions::new(), ctx.task_ctx().as_ref())
9995
.await
@@ -133,9 +129,7 @@ SELECT add_one(1), multiply_two(3);
133129
let ctx = SessionContext::new();
134130
let component = python_component().await;
135131

136-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]))
137-
.await
138-
.unwrap();
132+
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
139133
let parsed_query = parser
140134
.parse(query, &WasmPermissions::new(), ctx.task_ctx().as_ref())
141135
.await
@@ -169,9 +163,7 @@ SELECT add_one(1)
169163
let ctx = SessionContext::new();
170164
let component = python_component().await;
171165

172-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]))
173-
.await
174-
.unwrap();
166+
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
175167
let parsed_query = parser
176168
.parse(query, &WasmPermissions::new(), ctx.task_ctx().as_ref())
177169
.await

0 commit comments

Comments
 (0)