Skip to content

Commit 6a9e421

Browse files
authored
Merge pull request #151 from influxdata/tm/really-format-the-snake
feat: strip indentation from source code
2 parents cb4233f + 37fce55 commit 6a9e421

File tree

3 files changed

+267
-17
lines changed

3 files changed

+267
-17
lines changed

query/src/format/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//! Module for UDF code formatting implementations
2+
3+
/// Hook for rewriting UDF source code before it is compiled.
///
/// Implementations may apply language-specific formatting or preprocessing.
/// Every implementor must also be `Debug`, `Send` and `Sync`.
pub trait UdfCodeFormatter: Send + Sync + std::fmt::Debug {
    /// Takes ownership of the UDF source `code` and returns the formatted text.
    fn format(&self, code: String) -> String;
}
9+
10+
/// Default implementation that returns code unchanged
11+
#[derive(Debug, Default, Clone, Copy)]
12+
pub struct NoOpFormatter;
13+
14+
impl UdfCodeFormatter for NoOpFormatter {
15+
fn format(&self, code: String) -> String {
16+
code
17+
}
18+
}
19+
20+
/// Code formatter that strips leading indentation
21+
#[derive(Debug, Default, Clone, Copy)]
22+
pub struct StripIndentationFormatter;
23+
24+
impl UdfCodeFormatter for StripIndentationFormatter {
25+
fn format(&self, code: String) -> String {
26+
strip_indentation(&code)
27+
}
28+
}
29+
30+
/// Removes the indentation shared by every non-blank line of `code`.
///
/// The shared indentation is the longest run of ASCII whitespace that is a
/// literal prefix of every line containing non-whitespace text. Tabs and
/// spaces are *not* interchangeable: `"\tx"` and `"  y"` share no prefix, so
/// nothing is stripped from either (the same rule Python's `textwrap.dedent`
/// uses). Counting whitespace characters instead would strip one character
/// from each and change their relative nesting.
///
/// Whitespace-only lines do not take part in computing the prefix. They lose
/// the prefix if they start with it and are emitted empty otherwise.
///
/// Every line of the result is terminated by `'\n'`, including the last one,
/// and `"\r\n"` line endings are normalised to `"\n"`. Empty input yields an
/// empty string.
fn strip_indentation(code: &str) -> String {
    let prefix = code
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(leading_whitespace)
        .reduce(common_prefix)
        .unwrap_or_default();

    // The result is never longer than the input plus one trailing newline.
    let mut stripped = String::with_capacity(code.len() + 1);
    for line in code.lines() {
        // Non-blank lines start with `prefix` by construction, so the strip
        // can only fail for whitespace-only lines, which become empty.
        stripped.push_str(line.strip_prefix(prefix).unwrap_or_default());
        stripped.push('\n');
    }
    stripped
}

/// Returns the leading run of ASCII whitespace of `line`.
fn leading_whitespace(line: &str) -> &str {
    let end = line
        .find(|c: char| !c.is_ascii_whitespace())
        .unwrap_or(line.len());
    &line[..end]
}

/// Returns the longest prefix of `a` that is also a prefix of `b`.
///
/// Both arguments are expected to consist of ASCII whitespace only, so
/// comparing and slicing by byte cannot split a character.
fn common_prefix<'a>(a: &'a str, b: &'a str) -> &'a str {
    let shared = a
        .bytes()
        .zip(b.bytes())
        .take_while(|(x, y)| x == y)
        .count();
    &a[..shared]
}

query/src/lib.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@ use sqlparser::dialect::dialect_from_str;
1212
use datafusion_udf_wasm_host::{WasmComponentPrecompiled, WasmPermissions, WasmScalarUdf};
1313
use tokio::runtime::Handle;
1414

15+
use crate::format::UdfCodeFormatter;
16+
17+
/// Module for UDF code formatting implementations
18+
pub mod format;
19+
20+
/// Represents a supported UDF language with its associated WASM component
21+
/// and code formatter.
22+
#[derive(Debug)]
23+
pub struct Lang<'a> {
24+
/// Pre-compiled WASM component for the language
25+
pub component: &'a WasmComponentPrecompiled,
26+
/// Code formatter for the language
27+
pub formatter: Box<dyn UdfCodeFormatter>,
28+
}
29+
1530
/// A [ParsedQuery] contains the extracted UDFs and SQL query string
1631
#[derive(Debug)]
1732
pub struct ParsedQuery {
@@ -24,9 +39,9 @@ pub struct ParsedQuery {
2439
/// Handles the registration and invocation of UDF queries in DataFusion with a
2540
/// pre-compiled WASM component.
2641
pub struct UdfQueryParser<'a> {
27-
/// Pre-compiled WASM component.
28-
/// Necessary to create UDFs.
29-
components: HashMap<String, &'a WasmComponentPrecompiled>,
42+
/// Map of strings (eg "python") to supported UDF languages and their WASM
43+
/// components
44+
components: HashMap<String, Lang<'a>>,
3045
}
3146

3247
impl std::fmt::Debug for UdfQueryParser<'_> {
@@ -40,7 +55,7 @@ impl std::fmt::Debug for UdfQueryParser<'_> {
4055

4156
impl<'a> UdfQueryParser<'a> {
4257
/// Registers the UDF query in DataFusion.
43-
pub fn new(components: HashMap<String, &'a WasmComponentPrecompiled>) -> Self {
58+
pub fn new(components: HashMap<String, Lang<'a>>) -> Self {
4459
Self { components }
4560
}
4661

@@ -56,15 +71,18 @@ impl<'a> UdfQueryParser<'a> {
5671

5772
let mut udfs = vec![];
5873
for (lang, blocks) in code {
59-
let component = self.components.get(&lang).ok_or_else(|| {
74+
let lang = self.components.get(&lang).ok_or_else(|| {
6075
DataFusionError::Plan(format!(
6176
"no WASM component registered for language: {:?}",
6277
lang
6378
))
6479
})?;
6580

6681
for code in blocks {
67-
udfs.extend(WasmScalarUdf::new(component, permissions, io_rt.clone(), code).await?);
82+
let code = lang.formatter.format(code);
83+
udfs.extend(
84+
WasmScalarUdf::new(lang.component, permissions, io_rt.clone(), code).await?,
85+
);
6886
}
6987
}
7088

query/tests/integration.rs

Lines changed: 201 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ use datafusion::{
1313
};
1414
use datafusion_common::{Result as DataFusionResult, test_util::batches_to_string};
1515
use datafusion_udf_wasm_host::WasmPermissions;
16-
use datafusion_udf_wasm_query::{ParsedQuery, UdfQueryParser};
16+
use datafusion_udf_wasm_query::{
17+
Lang, ParsedQuery, UdfQueryParser,
18+
format::{NoOpFormatter, StripIndentationFormatter},
19+
};
1720
use tokio::runtime::Handle;
1821

1922
mod integration_tests;
@@ -52,8 +55,15 @@ SELECT add_one(1);
5255

5356
let ctx = SessionContext::new();
5457
let component = python_component().await;
55-
56-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
58+
let formatter = Box::new(NoOpFormatter);
59+
60+
let parser = UdfQueryParser::new(HashMap::from_iter([(
61+
"python".to_string(),
62+
Lang {
63+
component,
64+
formatter,
65+
},
66+
)]));
5767
let parsed_query = parser
5868
.parse(
5969
query,
@@ -101,8 +111,15 @@ SELECT add_one(1), multiply_two(3);
101111

102112
let ctx = SessionContext::new();
103113
let component = python_component().await;
104-
105-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
114+
let formatter = Box::new(NoOpFormatter);
115+
116+
let parser = UdfQueryParser::new(HashMap::from_iter([(
117+
"python".to_string(),
118+
Lang {
119+
component,
120+
formatter,
121+
},
122+
)]));
106123
let parsed_query = parser
107124
.parse(
108125
query,
@@ -146,8 +163,15 @@ SELECT add_one(1), multiply_two(3);
146163

147164
let ctx = SessionContext::new();
148165
let component = python_component().await;
149-
150-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
166+
let formatter = Box::new(NoOpFormatter);
167+
168+
let parser = UdfQueryParser::new(HashMap::from_iter([(
169+
"python".to_string(),
170+
Lang {
171+
component,
172+
formatter,
173+
},
174+
)]));
151175
let parsed_query = parser
152176
.parse(
153177
query,
@@ -185,8 +209,15 @@ SELECT add_one(1)
185209

186210
let ctx = SessionContext::new();
187211
let component = python_component().await;
188-
189-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
212+
let formatter = Box::new(NoOpFormatter);
213+
214+
let parser = UdfQueryParser::new(HashMap::from_iter([(
215+
"python".to_string(),
216+
Lang {
217+
component,
218+
formatter,
219+
},
220+
)]));
190221
let parsed_query = parser
191222
.parse(
192223
query,
@@ -219,8 +250,15 @@ EXPLAIN SELECT add_one(1);
219250

220251
let ctx = SessionContext::new();
221252
let component = python_component().await;
222-
223-
let parser = UdfQueryParser::new(HashMap::from_iter([("python".to_string(), component)]));
253+
let formatter = Box::new(NoOpFormatter);
254+
255+
let parser = UdfQueryParser::new(HashMap::from_iter([(
256+
"python".to_string(),
257+
Lang {
258+
component,
259+
formatter,
260+
},
261+
)]));
224262
let parsed_query = parser
225263
.parse(
226264
query,
@@ -248,3 +286,155 @@ EXPLAIN SELECT add_one(1);
248286
+---------------+--------------------------------------------------------+
249287
");
250288
}
289+
290+
#[tokio::test(flavor = "multi_thread")]
291+
async fn test_strip_indentation_everything_indented() {
292+
let query_lines = &[
293+
" CREATE FUNCTION add_one()",
294+
" LANGUAGE python",
295+
" AS '",
296+
" def add_one(x: int) -> int:",
297+
" ",
298+
" return x + 1",
299+
" ';",
300+
" ",
301+
" SELECT add_one(1);",
302+
];
303+
let query = query_lines.join("\n");
304+
305+
let ctx = SessionContext::new();
306+
let component = python_component().await;
307+
let formatter = Box::new(StripIndentationFormatter);
308+
309+
let parser = UdfQueryParser::new(HashMap::from_iter([(
310+
"python".to_string(),
311+
Lang {
312+
component,
313+
formatter,
314+
},
315+
)]));
316+
let parsed_query = parser
317+
.parse(
318+
&query,
319+
&WasmPermissions::new(),
320+
Handle::current(),
321+
ctx.task_ctx().as_ref(),
322+
)
323+
.await
324+
.unwrap();
325+
326+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
327+
let batch = df.collect().await.unwrap();
328+
329+
assert_batches_eq!(
330+
[
331+
"+-------------------+",
332+
"| add_one(Int64(1)) |",
333+
"+-------------------+",
334+
"| 2 |",
335+
"+-------------------+",
336+
],
337+
&batch
338+
);
339+
}
340+
341+
#[tokio::test(flavor = "multi_thread")]
342+
async fn test_strip_indentation_empty_lines_not_indented() {
343+
let query_lines = &[
344+
" CREATE FUNCTION add_one()",
345+
" LANGUAGE python",
346+
" AS '",
347+
" def add_one(x: int) -> int:",
348+
"",
349+
" return x + 1",
350+
" ';",
351+
"",
352+
" SELECT add_one(1);",
353+
];
354+
let query = query_lines.join("\n");
355+
356+
let ctx = SessionContext::new();
357+
let component = python_component().await;
358+
let formatter = Box::new(StripIndentationFormatter);
359+
360+
let parser = UdfQueryParser::new(HashMap::from_iter([(
361+
"python".to_string(),
362+
Lang {
363+
component,
364+
formatter,
365+
},
366+
)]));
367+
let parsed_query = parser
368+
.parse(
369+
&query,
370+
&WasmPermissions::new(),
371+
Handle::current(),
372+
ctx.task_ctx().as_ref(),
373+
)
374+
.await
375+
.unwrap();
376+
377+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
378+
let batch = df.collect().await.unwrap();
379+
380+
assert_batches_eq!(
381+
[
382+
"+-------------------+",
383+
"| add_one(Int64(1)) |",
384+
"+-------------------+",
385+
"| 2 |",
386+
"+-------------------+",
387+
],
388+
&batch
389+
);
390+
}
391+
392+
#[tokio::test(flavor = "multi_thread")]
393+
async fn test_strip_indentation_python_further_indented() {
394+
let query_lines = &[
395+
" CREATE FUNCTION add_one()",
396+
" LANGUAGE python",
397+
" AS '",
398+
" def add_one(x: int) -> int:",
399+
" return x + 1",
400+
" ';",
401+
" ",
402+
" SELECT add_one(1);",
403+
];
404+
let query = query_lines.join("\n");
405+
406+
let ctx = SessionContext::new();
407+
let component = python_component().await;
408+
let formatter = Box::new(StripIndentationFormatter);
409+
410+
let parser = UdfQueryParser::new(HashMap::from_iter([(
411+
"python".to_string(),
412+
Lang {
413+
component,
414+
formatter,
415+
},
416+
)]));
417+
let parsed_query = parser
418+
.parse(
419+
&query,
420+
&WasmPermissions::new(),
421+
Handle::current(),
422+
ctx.task_ctx().as_ref(),
423+
)
424+
.await
425+
.unwrap();
426+
427+
let df = UdfQueryInvocator::invoke(&ctx, parsed_query).await.unwrap();
428+
let batch = df.collect().await.unwrap();
429+
430+
assert_batches_eq!(
431+
[
432+
"+-------------------+",
433+
"| add_one(Int64(1)) |",
434+
"+-------------------+",
435+
"| 2 |",
436+
"+-------------------+",
437+
],
438+
&batch
439+
);
440+
}

0 commit comments

Comments
 (0)