Skip to content

Commit f749932

Browse files
authored
Merge pull request #97 from influxdata/tm/stateless-udf
feat: implement stateless udf registration
2 parents f28d004 + 97e1f66 commit f749932

File tree

8 files changed

+1017
-38
lines changed

8 files changed

+1017
-38
lines changed

Cargo.lock

Lines changed: 644 additions & 37 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@ resolver = "3"
1818
anyhow = { version = "1.0.100", default-features = false }
1919
arrow = { version = "55.2.0", default-features = false, features = ["ipc"] }
2020
chrono = { version = "0.4.42", default-features = false }
21+
datafusion = { version = "49.0.1", default-features = false }
2122
datafusion-common = { version = "49.0.1", default-features = false }
2223
datafusion-expr = { version = "49.0.1", default-features = false }
24+
datafusion-sql = { version = "49.0.1", default-features = false }
2325
datafusion-udf-wasm-arrow2bytes = { path = "arrow2bytes", version = "0.1.0" }
2426
datafusion-udf-wasm-bundle = { path = "guests/bundle", version = "0.1.0" }
2527
datafusion-udf-wasm-guest = { path = "guests/rust", version = "0.1.0" }
2628
datafusion-udf-wasm-python = { path = "guests/python", version = "0.1.0" }
29+
sqlparser = { version = "0.55.0", default-features = false, features = ["std", "visitor"] }
2730
http = { version = "1.3.1", default-features = false }
2831
hyper = { version = "1.7", default-features = false }
2932
tokio = { version = "1.48.0", default-features = false }
@@ -65,8 +68,10 @@ private_intra_doc_links = "deny"
6568
[patch.crates-io]
6669
# use same DataFusion fork as InfluxDB
6770
# See https://github.com/influxdata/arrow-datafusion/pull/72
71+
datafusion = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
6872
datafusion-common = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
6973
datafusion-expr = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
74+
datafusion-sql = { git = "https://github.com/influxdata/arrow-datafusion.git", rev = "8347a71f62d4fef8d37548f22b93877170039357" }
7075

7176
# faster tests
7277
[profile.dev.package]

host/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@ workspace = true
1111
[dependencies]
1212
anyhow.workspace = true
1313
arrow.workspace = true
14+
datafusion.workspace = true
1415
datafusion-common.workspace = true
1516
datafusion-expr.workspace = true
1617
datafusion-udf-wasm-arrow2bytes.workspace = true
18+
datafusion-sql.workspace = true
19+
sqlparser.workspace = true
1720
http.workspace = true
1821
hyper.workspace = true
1922
rand = { version = "0.9" }

host/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ mod error;
4343
pub mod http;
4444
mod linker;
4545
mod tokio_helpers;
46+
pub mod udf_query;
4647
mod vfs;
4748

4849
/// State of the WASM payload.

host/src/udf_query.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
//! Embedded SQL approach for executing UDFs within SQL queries.
2+
3+
use std::collections::HashMap;
4+
5+
use datafusion::execution::TaskContext;
6+
use datafusion_common::{DataFusionError, Result as DataFusionResult};
7+
use datafusion_sql::parser::{DFParserBuilder, Statement};
8+
use sqlparser::ast::{CreateFunctionBody, Expr, Statement as SqlStatement, Value};
9+
use sqlparser::dialect::dialect_from_str;
10+
11+
use crate::{WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf};
12+
13+
/// A [ParsedQuery] contains the extracted UDFs and SQL query string
14+
#[derive(Debug)]
15+
pub struct ParsedQuery {
16+
/// Extracted UDFs from the query
17+
pub udfs: Vec<WasmScalarUdf>,
18+
/// SQL query string with UDF definitions removed
19+
pub sql: String,
20+
}
21+
22+
/// Handles the registration and invocation of UDF queries in DataFusion with a
23+
/// pre-compiled WASM component.
24+
pub struct UdfQueryParser<'a> {
25+
/// Pre-compiled WASM component.
26+
/// Necessary to create UDFs.
27+
components: HashMap<String, &'a WasmComponentPrecompiled>,
28+
}
29+
30+
impl std::fmt::Debug for UdfQueryParser<'_> {
31+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32+
f.debug_struct("UdfQueryParser")
33+
.field("session_ctx", &"SessionContext { ... }")
34+
.field("components", &self.components)
35+
.finish()
36+
}
37+
}
38+
39+
impl<'a> UdfQueryParser<'a> {
40+
/// Registers the UDF query in DataFusion.
41+
pub fn new(components: HashMap<String, &'a WasmComponentPrecompiled>) -> Self {
42+
Self { components }
43+
}
44+
45+
/// Parses a SQL query that defines & uses UDFs into a [ParsedQuery].
46+
pub async fn parse(
47+
&self,
48+
udf_query: &str,
49+
permissions: &WasmPermissions,
50+
task_ctx: &TaskContext,
51+
) -> DataFusionResult<ParsedQuery> {
52+
let (code, sql) = self.parse_inner(udf_query, task_ctx)?;
53+
54+
let mut udfs = vec![];
55+
for (lang, blocks) in code {
56+
let component = self.components.get(&lang).ok_or_else(|| {
57+
DataFusionError::Plan(format!(
58+
"no WASM component registered for language: {:?}",
59+
lang
60+
))
61+
})?;
62+
63+
for code in blocks {
64+
udfs.extend(WasmScalarUdf::new(component, permissions, code).await?);
65+
}
66+
}
67+
68+
Ok(ParsedQuery { udfs, sql })
69+
}
70+
71+
/// Parse the combined query to extract the chosen UDF language, UDF
72+
/// definitions, and SQL statements.
73+
fn parse_inner(
74+
&self,
75+
query: &str,
76+
task_ctx: &TaskContext,
77+
) -> DataFusionResult<(HashMap<String, Vec<String>>, String)> {
78+
let options = task_ctx.session_config().options();
79+
80+
let dialect = dialect_from_str(options.sql_parser.dialect.clone()).expect("valid dialect");
81+
let recursion_limit = options.sql_parser.recursion_limit;
82+
83+
let statements = DFParserBuilder::new(query)
84+
.with_dialect(dialect.as_ref())
85+
.with_recursion_limit(recursion_limit)
86+
.build()?
87+
.parse_statements()?;
88+
89+
let mut sql = String::new();
90+
let mut udf_blocks: HashMap<String, Vec<String>> = HashMap::new();
91+
for s in statements {
92+
let Statement::Statement(stmt) = s else {
93+
continue;
94+
};
95+
96+
match parse_udf(*stmt)? {
97+
Parsed::Udf { code, language } => {
98+
if let Some(existing) = udf_blocks.get_mut(&language) {
99+
existing.push(code);
100+
} else {
101+
udf_blocks.insert(language.clone(), vec![code]);
102+
}
103+
}
104+
Parsed::Other(statement) => {
105+
sql.push_str(&statement);
106+
sql.push_str(";\n");
107+
}
108+
}
109+
}
110+
111+
if sql.is_empty() {
112+
return Err(DataFusionError::Plan("no SQL query found".to_string()));
113+
}
114+
115+
Ok((udf_blocks, sql))
116+
}
117+
}
118+
119+
/// Classification of a single SQL statement.
enum Parsed {
    /// The statement defines a UDF.
    Udf {
        /// Source code of the UDF, taken from the function body.
        code: String,
        /// Language the UDF is written in.
        language: String,
    },
    /// The statement is anything else; holds its SQL text.
    Other(String),
}
131+
132+
/// Parse a single SQL statement to extract a UDF
133+
fn parse_udf(stmt: SqlStatement) -> DataFusionResult<Parsed> {
134+
match stmt {
135+
SqlStatement::CreateFunction(cf) => {
136+
let function_body = cf.function_body.as_ref();
137+
138+
let language = if let Some(lang) = cf.language.as_ref() {
139+
lang.to_string()
140+
} else {
141+
return Err(DataFusionError::Plan(
142+
"function language is required for UDFs".to_string(),
143+
));
144+
};
145+
146+
let code = match function_body {
147+
Some(body) => extract_function_body(body),
148+
None => Err(DataFusionError::Plan(
149+
"function body is required for UDFs".to_string(),
150+
)),
151+
}?;
152+
153+
Ok(Parsed::Udf {
154+
code: code.to_string(),
155+
language,
156+
})
157+
}
158+
_ => Ok(Parsed::Other(stmt.to_string())),
159+
}
160+
}
161+
162+
/// Extracts the code from the function body, adding it to `code`.
163+
fn extract_function_body(body: &CreateFunctionBody) -> DataFusionResult<&str> {
164+
match body {
165+
CreateFunctionBody::AsAfterOptions(e) | CreateFunctionBody::AsBeforeOptions(e) => {
166+
expression_into_str(e)
167+
}
168+
CreateFunctionBody::Return(_) => Err(DataFusionError::Plan(
169+
"`RETURN` function body not supported for UDFs".to_string(),
170+
)),
171+
}
172+
}
173+
174+
/// Attempt to convert an `Expr` into a `str`
175+
fn expression_into_str(expr: &Expr) -> DataFusionResult<&str> {
176+
match expr {
177+
Expr::Value(v) => match &v.value {
178+
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(s),
179+
_ => Err(DataFusionError::Plan("expected string value".to_string())),
180+
},
181+
_ => Err(DataFusionError::Plan(
182+
"expected value expression".to_string(),
183+
)),
184+
}
185+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
mod python;
22
mod rust;
33
mod test_utils;
4+
mod udf_query;

host/tests/integration_tests/python/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ mod examples;
33
mod inspection;
44
mod runtime;
55
mod state;
6-
mod test_utils;
6+
pub(crate) mod test_utils;
77
mod types;

0 commit comments

Comments
 (0)