Skip to content

Commit 9e540ae

Browse files
authored
Merge pull request #145 from influxdata/crepererum/fix-udf-explain
fix: UDF extraction and `EXPLAIN` queries
2 parents 6aa8e06 + 13607b1 commit 9e540ae

File tree

2 files changed

+70
-31
lines changed

2 files changed

+70
-31
lines changed

host/src/udf_query.rs

Lines changed: 28 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -89,11 +89,7 @@ impl<'a> UdfQueryParser<'a> {
8989
let mut sql = String::new();
9090
let mut udf_blocks: HashMap<String, Vec<String>> = HashMap::new();
9191
for s in statements {
92-
let Statement::Statement(stmt) = s else {
93-
continue;
94-
};
95-
96-
match parse_udf(*stmt)? {
92+
match parse_udf(s)? {
9793
Parsed::Udf { code, language } => {
9894
if let Some(existing) = udf_blocks.get_mut(&language) {
9995
existing.push(code);
@@ -130,31 +126,34 @@ enum Parsed {
130126
}
131127

132128
/// Parse a single SQL statement to extract a UDF
133-
fn parse_udf(stmt: SqlStatement) -> DataFusionResult<Parsed> {
129+
fn parse_udf(stmt: Statement) -> DataFusionResult<Parsed> {
134130
match stmt {
135-
SqlStatement::CreateFunction(cf) => {
136-
let function_body = cf.function_body.as_ref();
137-
138-
let language = if let Some(lang) = cf.language.as_ref() {
139-
lang.to_string()
140-
} else {
141-
return Err(DataFusionError::Plan(
142-
"function language is required for UDFs".to_string(),
143-
));
144-
};
145-
146-
let code = match function_body {
147-
Some(body) => extract_function_body(body),
148-
None => Err(DataFusionError::Plan(
149-
"function body is required for UDFs".to_string(),
150-
)),
151-
}?;
152-
153-
Ok(Parsed::Udf {
154-
code: code.to_string(),
155-
language,
156-
})
157-
}
131+
Statement::Statement(stmt) => match *stmt {
132+
SqlStatement::CreateFunction(cf) => {
133+
let function_body = cf.function_body.as_ref();
134+
135+
let language = if let Some(lang) = cf.language.as_ref() {
136+
lang.to_string()
137+
} else {
138+
return Err(DataFusionError::Plan(
139+
"function language is required for UDFs".to_string(),
140+
));
141+
};
142+
143+
let code = match function_body {
144+
Some(body) => extract_function_body(body),
145+
None => Err(DataFusionError::Plan(
146+
"function body is required for UDFs".to_string(),
147+
)),
148+
}?;
149+
150+
Ok(Parsed::Udf {
151+
code: code.to_string(),
152+
language,
153+
})
154+
}
155+
_ => Ok(Parsed::Other(stmt.to_string())),
156+
},
158157
_ => Ok(Parsed::Other(stmt.to_string())),
159158
}
160159
}

host/tests/integration_tests/udf_query.rs

Lines changed: 42 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ use datafusion::{
44
assert_batches_eq,
55
prelude::{DataFrame, SessionContext},
66
};
7-
use datafusion_common::Result as DataFusionResult;
7+
use datafusion_common::{Result as DataFusionResult, test_util::batches_to_string};
88
use datafusion_udf_wasm_host::{
99
WasmPermissions,
1010
udf_query::{ParsedQuery, UdfQueryParser},
@@ -77,7 +77,7 @@ def add_one(x: int) -> int:
7777
';
7878
7979
CREATE FUNCTION multiply_two()
80-
LANGUAGE python
80+
LANGUAGE python
8181
AS '
8282
def multiply_two(x: int) -> int:
8383
return x * 2
@@ -175,3 +175,43 @@ SELECT add_one(1)
175175
let err = r.err().unwrap();
176176
assert!(err.message().contains("Invalid function 'add_one'"));
177177
}
178+
179+
#[tokio::test(flavor = "multi_thread")]
180+
async fn test_explain() {
181+
let query = r#"
182+
CREATE FUNCTION add_one()
183+
LANGUAGE python
184+
AS '
185+
def add_one(x: int) -> int:
186+
return x + 1
187+
';
188+
189+
EXPLAIN SELECT add_one(1);
190+
"#;
191+
192+
let ctx = SessionContext::new();
193+
let component = python_component().await;
194+
195+
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
196+
let parsed_query = parser
197+
.parse(query, &WasmPermissions::new(), ctx.task_ctx().as_ref())
198+
.await
199+
.unwrap();
200+
201+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
202+
let batch = df.collect().await.unwrap();
203+
204+
insta::assert_snapshot!(
205+
batches_to_string(&batch),
206+
@r"
207+
+---------------+--------------------------------------------------------+
208+
| plan_type | plan |
209+
+---------------+--------------------------------------------------------+
210+
| logical_plan | Projection: add_one(Int64(1)) |
211+
| | EmptyRelation |
212+
| physical_plan | ProjectionExec: expr=[add_one(1) as add_one(Int64(1))] |
213+
| | PlaceholderRowExec |
214+
| | |
215+
+---------------+--------------------------------------------------------+
216+
");
217+
}

0 commit comments

Comments
 (0)