Skip to content

Commit d3d4754

Browse files
committed
feat: create UdfCodeFormatter trait and implement it for StripIndentationFormatter
1 parent cb4233f commit d3d4754

File tree

3 files changed

+245
-15
lines changed

3 files changed

+245
-15
lines changed

query/src/format/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//! Module for UDF code formatting implementations
2+
3+
/// Hook for rewriting UDF source code before it is compiled.
///
/// Implementations can apply language-specific formatting or preprocessing
/// to the code they are given.
pub trait UdfCodeFormatter: Send + Sync + std::fmt::Debug {
    /// Consumes the UDF `code` and returns its formatted form.
    fn format(&self, code: String) -> String;
}
9+
10+
/// Default implementation that returns code unchanged
11+
#[derive(Debug, Clone, Copy)]
12+
pub struct NoOpFormatter;
13+
14+
impl UdfCodeFormatter for NoOpFormatter {
15+
fn format(&self, code: String) -> String {
16+
code
17+
}
18+
}
19+
20+
/// Code formatter that strips leading indentation
21+
#[derive(Debug, Clone, Copy)]
22+
pub struct StripIndentationFormatter;
23+
24+
impl UdfCodeFormatter for StripIndentationFormatter {
25+
fn format(&self, code: String) -> String {
26+
strip_indentation(&code)
27+
}
28+
}
29+
30+
/// Removes the shared leading indentation from `code`.
///
/// The indentation width is the smallest number of leading ASCII-whitespace
/// characters found on any line that is not blank; whitespace-only lines are
/// ignored while measuring. That many leading characters are then dropped
/// from *every* line, and blank lines shorter than the width become empty.
///
/// Tabs and spaces each count as one character, so mixed indentation is not
/// reconciled. Every output line is terminated with `\n`, which means `\r\n`
/// endings are normalised and a final newline is added if the input lacked
/// one. An empty input yields an empty string.
fn strip_indentation(code: &str) -> String {
    // Shallowest indentation (in characters) among the lines that hold code.
    let indent = code
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| line.chars().take_while(|c| c.is_ascii_whitespace()).count())
        .min()
        .unwrap_or(0);

    // The result is never longer than the input plus one trailing newline.
    let mut stripped = String::with_capacity(code.len() + 1);
    for line in code.lines() {
        // Byte offset of the `indent`-th character; lines with fewer
        // characters than that (only possible for blank lines) are emptied.
        let start = line
            .char_indices()
            .nth(indent)
            .map_or(line.len(), |(offset, _)| offset);
        stripped.push_str(&line[start..]);
        stripped.push('\n');
    }
    stripped
}

query/src/lib.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@ use sqlparser::dialect::dialect_from_str;
1212
use datafusion_udf_wasm_host::{WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf};
1313
use tokio::runtime::Handle;
1414

15+
use crate::format::UdfCodeFormatter;
16+
17+
/// Module for UDF code formatting implementations
18+
pub mod format;
19+
20+
/// Represents a supported UDF language with its associated WASM component
21+
/// and code formatter.
22+
#[derive(Debug)]
23+
pub struct Lang<'a> {
24+
/// Pre-compiled WASM component for the language
25+
pub component: &'a WasmComponentPrecompiled,
26+
/// Code formatter for the language
27+
pub formatter: Box<dyn UdfCodeFormatter>,
28+
}
29+
1530
/// A [ParsedQuery] contains the extracted UDFs and SQL query string
1631
#[derive(Debug)]
1732
pub struct ParsedQuery {
@@ -26,7 +41,7 @@ pub struct ParsedQuery {
2641
pub struct UdfQueryParser<'a> {
2742
/// Pre-compiled WASM component.
2843
/// Necessary to create UDFs.
29-
components: HashMap<String, &'a WasmComponentPrecompiled>,
44+
components: HashMap<String, Lang<'a>>,
3045
}
3146

3247
impl std::fmt::Debug for UdfQueryParser<'_> {
@@ -40,7 +55,7 @@ impl std::fmt::Debug for UdfQueryParser<'_> {
4055

4156
impl<'a> UdfQueryParser<'a> {
4257
/// Creates a parser from the supported languages, keyed by language name.
43-
pub fn new(components: HashMap<String, &'a WasmComponentPrecompiled>) -> Self {
58+
pub fn new(components: HashMap<String, Lang<'a>>) -> Self {
4459
Self { components }
4560
}
4661

@@ -56,15 +71,16 @@ impl<'a> UdfQueryParser<'a> {
5671

5772
let mut udfs = vec![];
5873
for (lang, blocks) in code {
59-
let component = self.components.get(&lang).ok_or_else(|| {
74+
let lang = self.components.get(&lang).ok_or_else(|| {
6075
DataFusionError::Plan(format!(
6176
"no WASM component registered for language: {:?}",
6277
lang
6378
))
6479
})?;
6580

6681
for code in blocks {
67-
udfs.extend(WasmScalarUdf::new(component, permissions, io_rt.clone(), code).await?);
82+
let code = lang.formatter.format(code);
83+
udfs.extend(WasmScalarUdf::new(lang.component, permissions, io_rt.clone(), code).await?);
6884
}
6985
}
7086

query/tests/integration.rs

Lines changed: 183 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ use datafusion::{
1313
};
1414
use datafusion_common::{Result as DataFusionResult, test_util::batches_to_string};
1515
use datafusion_udf_wasm_host::WasmPermissions;
16-
use datafusion_udf_wasm_query::{ParsedQuery, UdfQueryParser};
16+
use datafusion_udf_wasm_query::{
17+
Lang, ParsedQuery, UdfQueryParser,
18+
format::{NoOpFormatter, StripIndentationFormatter},
19+
};
1720
use tokio::runtime::Handle;
1821

1922
mod integration_tests;
@@ -52,8 +55,15 @@ SELECT add_one(1);
5255

5356
let ctx = SessionContext::new();
5457
let component = python_component().await;
55-
56-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
58+
let formatter = Box::new(NoOpFormatter);
59+
60+
let parser = UdfQueryParser::new(HashMap::from_iter([(
61+
"python".to_string(),
62+
Lang {
63+
component,
64+
formatter,
65+
},
66+
)]));
5767
let parsed_query = parser
5868
.parse(
5969
query,
@@ -101,8 +111,15 @@ SELECT add_one(1), multiply_two(3);
101111

102112
let ctx = SessionContext::new();
103113
let component = python_component().await;
104-
105-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
114+
let formatter = Box::new(NoOpFormatter);
115+
116+
let parser = UdfQueryParser::new(HashMap::from_iter([(
117+
"python".to_string(),
118+
Lang {
119+
component,
120+
formatter,
121+
},
122+
)]));
106123
let parsed_query = parser
107124
.parse(
108125
query,
@@ -146,8 +163,15 @@ SELECT add_one(1), multiply_two(3);
146163

147164
let ctx = SessionContext::new();
148165
let component = python_component().await;
149-
150-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
166+
let formatter = Box::new(NoOpFormatter);
167+
168+
let parser = UdfQueryParser::new(HashMap::from_iter([(
169+
"python".to_string(),
170+
Lang {
171+
component,
172+
formatter,
173+
},
174+
)]));
151175
let parsed_query = parser
152176
.parse(
153177
query,
@@ -185,8 +209,15 @@ SELECT add_one(1)
185209

186210
let ctx = SessionContext::new();
187211
let component = python_component().await;
188-
189-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
212+
let formatter = Box::new(NoOpFormatter);
213+
214+
let parser = UdfQueryParser::new(HashMap::from_iter([(
215+
"python".to_string(),
216+
Lang {
217+
component,
218+
formatter,
219+
},
220+
)]));
190221
let parsed_query = parser
191222
.parse(
192223
query,
@@ -219,8 +250,15 @@ EXPLAIN SELECT add_one(1);
219250

220251
let ctx = SessionContext::new();
221252
let component = python_component().await;
222-
223-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
253+
let formatter = Box::new(NoOpFormatter);
254+
255+
let parser = UdfQueryParser::new(HashMap::from_iter([(
256+
"python".to_string(),
257+
Lang {
258+
component,
259+
formatter,
260+
},
261+
)]));
224262
let parsed_query = parser
225263
.parse(
226264
query,
@@ -248,3 +286,137 @@ EXPLAIN SELECT add_one(1);
248286
+---------------+--------------------------------------------------------+
249287
");
250288
}
289+
290+
// Runs a Python `add_one` UDF through a parser configured with
// `StripIndentationFormatter` and checks that `add_one(1)` evaluates to 2.
// NOTE(review): per the test name, every line of the embedded Python
// (including the blank one) is presumably indented; confirm against the raw
// string, whose whitespace is significant.
#[tokio::test(flavor = "multi_thread")]
291+
async fn test_strip_indentation_everything_indented() {
292+
let query = r#"
293+
CREATE FUNCTION add_one()
294+
LANGUAGE python
295+
AS '
296+
def add_one(x: int) -> int:
297+
298+
return x + 1
299+
';
300+
301+
SELECT add_one(1);
302+
"#;
303+
304+
let ctx = SessionContext::new();
305+
let component = python_component().await;
306+
// Formatter under test.
let formatter = Box::new(StripIndentationFormatter);
307+
308+
// Register the component and formatter under the "python" language key.
let parser = UdfQueryParser::new(HashMap::from_iter([(
309+
"python".to_string(),
310+
Lang {
311+
component,
312+
formatter,
313+
},
314+
)]));
315+
let parsed_query = parser
316+
.parse(query, &WasmPermissions::new(), Handle::current(), ctx.task_ctx().as_ref())
317+
.await
318+
.unwrap();
319+
320+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
321+
let batch = df.collect().await.unwrap();
322+
323+
// The UDF must have compiled and run: add_one(1) == 2.
assert_batches_eq!(
324+
[
325+
"+-------------------+",
326+
"| add_one(Int64(1)) |",
327+
"+-------------------+",
328+
"| 2 |",
329+
"+-------------------+",
330+
],
331+
&batch
332+
);
333+
}
334+
335+
// Runs a Python `add_one` UDF through a parser configured with
// `StripIndentationFormatter` and checks that `add_one(1)` evaluates to 2.
// NOTE(review): per the test name, the blank line inside the embedded Python
// presumably carries no indentation while the code lines do; confirm against
// the raw string, whose whitespace is significant.
#[tokio::test(flavor = "multi_thread")]
336+
async fn test_strip_indentation_empty_lines_not_indented() {
337+
let query = r#"
338+
CREATE FUNCTION add_one()
339+
LANGUAGE python
340+
AS '
341+
def add_one(x: int) -> int:
342+
343+
return x + 1
344+
';
345+
346+
SELECT add_one(1);
347+
"#;
348+
349+
let ctx = SessionContext::new();
350+
let component = python_component().await;
351+
// Formatter under test.
let formatter = Box::new(StripIndentationFormatter);
352+
353+
// Register the component and formatter under the "python" language key.
let parser = UdfQueryParser::new(HashMap::from_iter([(
354+
"python".to_string(),
355+
Lang {
356+
component,
357+
formatter,
358+
},
359+
)]));
360+
let parsed_query = parser
361+
.parse(query, &WasmPermissions::new(), Handle::current(), ctx.task_ctx().as_ref())
362+
.await
363+
.unwrap();
364+
365+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
366+
let batch = df.collect().await.unwrap();
367+
368+
// The UDF must have compiled and run: add_one(1) == 2.
assert_batches_eq!(
369+
[
370+
"+-------------------+",
371+
"| add_one(Int64(1)) |",
372+
"+-------------------+",
373+
"| 2 |",
374+
"+-------------------+",
375+
],
376+
&batch
377+
);
378+
}
379+
380+
// Runs a Python `add_one` UDF through a parser configured with
// `StripIndentationFormatter` and checks that `add_one(1)` evaluates to 2.
// NOTE(review): per the test name, the embedded Python is presumably indented
// deeper than the surrounding SQL; confirm against the raw string, whose
// whitespace is significant.
#[tokio::test(flavor = "multi_thread")]
381+
async fn test_strip_indentation_python_further_indented() {
382+
let query = r#"
383+
CREATE FUNCTION add_one()
384+
LANGUAGE python
385+
AS '
386+
def add_one(x: int) -> int:
387+
return x + 1
388+
';
389+
390+
SELECT add_one(1);
391+
"#;
392+
393+
let ctx = SessionContext::new();
394+
let component = python_component().await;
395+
// Formatter under test.
let formatter = Box::new(StripIndentationFormatter);
396+
397+
// Register the component and formatter under the "python" language key.
let parser = UdfQueryParser::new(HashMap::from_iter([(
398+
"python".to_string(),
399+
Lang {
400+
component,
401+
formatter,
402+
},
403+
)]));
404+
let parsed_query = parser
405+
.parse(query, &WasmPermissions::new(), Handle::current(), ctx.task_ctx().as_ref())
406+
.await
407+
.unwrap();
408+
409+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
410+
let batch = df.collect().await.unwrap();
411+
412+
// The UDF must have compiled and run: add_one(1) == 2.
assert_batches_eq!(
413+
[
414+
"+-------------------+",
415+
"| add_one(Int64(1)) |",
416+
"+-------------------+",
417+
"| 2 |",
418+
"+-------------------+",
419+
],
420+
&batch
421+
);
422+
}

0 commit comments

Comments
 (0)