Skip to content

Commit 8a03dab

Browse files
committed
feat: fully functional udtfs for all tpch tables
1 parent 57d5ba3 commit 8a03dab

File tree

1 file changed

+190
-27
lines changed

1 file changed

+190
-27
lines changed

src/lib.rs

Lines changed: 190 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,217 @@
1+
use datafusion::arrow::compute::concat_batches;
12
use datafusion::catalog::{TableFunctionImpl, TableProvider};
23
use datafusion::common::{Result, ScalarValue, plan_err};
34
use datafusion::datasource::memory::MemTable;
45
use datafusion_expr::Expr;
56
use std::sync::Arc;
67
use tpchgen_arrow::{NationArrow, RecordBatchIterator};
78

8-
/// Table function that returns the TPCH nation table.
9-
#[derive(Debug)]
10-
pub struct TpchNationFunction {}
9+
/// Defines a table function provider and its implementation using [`tpchgen`]
10+
/// as the data source.
11+
macro_rules! define_tpch_udtf_provider {
12+
($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => {
13+
/// Tablle Function that returns the $TABLE_FUNCTION_NAME table.
14+
///
15+
/// This function is a wrapper around the [`tpchgen`] library and builds
16+
/// a table provider that can be used in a DataFusion query.
17+
///
18+
/// The expected arguments are a float literal for the scale factor,
19+
/// an i64 literal for the part, and an i64 literal for the number of parts.
20+
/// The second and third arguments are optional and will default to 1
21+
/// for both values which tells the generator to generate all parts.
22+
///
23+
/// # Examples
24+
///
25+
/// -- This example generates TPCH nation data with scale factor 1.0.
26+
/// SELECT * FROM tpchgen_nation(1.0);
27+
///
28+
/// -- This example generates TPCH order data with scale factor 10 and generates
29+
/// -- the second part with 5 parts.
30+
/// SELECT * FROM tpchgen_order(10.0, 2, 5);
31+
#[derive(Debug)]
32+
pub struct $TABLE_FUNCTION_NAME {}
1133

12-
impl TableFunctionImpl for TpchNationFunction {
13-
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
14-
let Some(Expr::Literal(ScalarValue::Float64(Some(value)))) = args.get(0) else {
15-
return plan_err!("First argument must be a float literal.");
16-
};
34+
impl $TABLE_FUNCTION_NAME {
35+
/// Returns the name of the table function.
36+
pub fn name() -> &'static str {
37+
stringify!($TABLE_FUNCTION_SQL_NAME)
38+
}
39+
}
1740

18-
// Init the table generator.
19-
let tablegen = tpchgen::generators::NationGenerator::new(*value, 0, 0);
41+
impl TableFunctionImpl for $TABLE_FUNCTION_NAME {
42+
/// Implementation of the UDTF invocation for TPCH table generation
43+
/// using the [`tpchgen`] library.
44+
///
45+
/// The first argument is a float literal that specifies the scale factor.
46+
/// The second argument is the part to generate.
47+
/// The third argument is the number of parts to generate.
48+
///
49+
/// The second and third argument are optional and will default to 1
50+
/// for both values which tells the generator to generate all parts.
51+
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
52+
let Some(Expr::Literal(ScalarValue::Float64(Some(value)))) = args.get(0) else {
53+
return plan_err!("First argument must be a float literal.");
54+
};
2055

21-
// Init the arrow provider.
22-
let mut arrow_tablegen = NationArrow::new(tablegen);
56+
// Default values for part and num_parts.
57+
let part = 1;
58+
let num_parts = 1;
2359

24-
let batch = arrow_tablegen.next().unwrap();
60+
// Check if we have more arguments `part` and `num_parts` respectively
61+
// and if they are i64 literals.
62+
if args.len() > 1 {
63+
// Check if the second argument and third arguments are i64 literals and
64+
// greater than 0.
65+
let Some(Expr::Literal(ScalarValue::Int64(Some(part)))) = args.get(1) else {
66+
return plan_err!("Second argument must be an i64 literal.");
67+
};
68+
let Some(Expr::Literal(ScalarValue::Int64(Some(num_parts)))) = args.get(2)
69+
else {
70+
return plan_err!("Third argument must be an i64 literal.");
71+
};
72+
if *part < 0 || *num_parts < 0 {
73+
return plan_err!("Second and third arguments must be greater than 0.");
74+
}
75+
}
2576

26-
// Build the memtable plan.
27-
let provider = MemTable::try_new(arrow_tablegen.schema().clone(), vec![vec![batch]])?;
77+
// Init the table generator.
78+
let tablegen = <$GENERATOR>::new(*value, part, num_parts);
2879

29-
Ok(Arc::new(provider))
30-
}
80+
// Init the arrow provider.
81+
let mut arrow_tablegen = <$ARROW_GENERATOR>::new(tablegen);
82+
83+
// The arrow provider is a batched generator with a default batch size of 8000
84+
// so to build the full table we need to call `next` until it returns None.
85+
let mut batches = Vec::new();
86+
while let Some(batch) = arrow_tablegen.next() {
87+
batches.push(batch);
88+
}
89+
// Use `concat_batches` to create a single batch from the vector of batches.
90+
// This is needed because the `MemTable` provider requires a single batch.
91+
// This is a bit of a hack, but it works.
92+
let batch = concat_batches(arrow_tablegen.schema(), &batches)?;
93+
94+
// Build the memtable plan.
95+
let provider =
96+
MemTable::try_new(arrow_tablegen.schema().clone(), vec![vec![batch]])?;
97+
98+
Ok(Arc::new(provider))
99+
}
100+
}
101+
};
31102
}
32103

104+
define_tpch_udtf_provider!(
105+
TpchNation,
106+
tpch_nation,
107+
tpchgen::generators::NationGenerator,
108+
NationArrow
109+
);
110+
111+
define_tpch_udtf_provider!(
112+
TpchCustomer,
113+
tpch_customer,
114+
tpchgen::generators::CustomerGenerator,
115+
tpchgen_arrow::CustomerArrow
116+
);
117+
118+
define_tpch_udtf_provider!(
119+
TpchOrders,
120+
tpch_orders,
121+
tpchgen::generators::OrderGenerator,
122+
tpchgen_arrow::OrderArrow
123+
);
124+
125+
define_tpch_udtf_provider!(
126+
TpchLineitem,
127+
tpch_lineitem,
128+
tpchgen::generators::LineItemGenerator,
129+
tpchgen_arrow::LineItemArrow
130+
);
131+
132+
define_tpch_udtf_provider!(
133+
TpchPart,
134+
tpch_part,
135+
tpchgen::generators::PartGenerator,
136+
tpchgen_arrow::PartArrow
137+
);
138+
139+
define_tpch_udtf_provider!(
140+
TpchPartsupp,
141+
tpch_partsupp,
142+
tpchgen::generators::PartSuppGenerator,
143+
tpchgen_arrow::PartSuppArrow
144+
);
145+
146+
define_tpch_udtf_provider!(
147+
TpchSupplier,
148+
tpch_supplier,
149+
tpchgen::generators::SupplierGenerator,
150+
tpchgen_arrow::SupplierArrow
151+
);
152+
153+
define_tpch_udtf_provider!(
154+
TpchRegion,
155+
tpch_region,
156+
tpchgen::generators::RegionGenerator,
157+
tpchgen_arrow::RegionArrow
158+
);
159+
33160
#[cfg(test)]
34161
mod tests {
35162
use super::*;
36163
use datafusion::execution::context::SessionContext;
37164

38165
#[tokio::test]
39-
async fn test_tpchgen_function() -> Result<()> {
166+
async fn test_tpch_functions() -> Result<()> {
40167
let ctx = SessionContext::new();
41-
ctx.register_udtf("tpchgen_nation", Arc::new(TpchNationFunction {}));
42168

43-
let df = ctx
44-
.sql("SELECT * FROM tpchgen_nation(1.0)")
45-
.await?
46-
.collect()
47-
.await?;
169+
// Register all the UDTFs.
170+
ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {}));
171+
ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {}));
172+
ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {}));
173+
ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {}));
174+
ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {}));
175+
ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {}));
176+
ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {}));
177+
ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {}));
178+
179+
// Test all the UDTFs, the constants were computed using the tpchgen library
180+
// and the expected values are the number of rows and columns for each table.
181+
let test_cases = vec![
182+
(TpchNation::name(), 25, 4),
183+
(TpchCustomer::name(), 150000, 8),
184+
(TpchOrders::name(), 1500000, 9),
185+
(TpchLineitem::name(), 6001215, 16),
186+
(TpchPart::name(), 200000, 9),
187+
(TpchPartsupp::name(), 800000, 5),
188+
(TpchSupplier::name(), 10000, 7),
189+
(TpchRegion::name(), 5, 3),
190+
];
191+
192+
for (function, expected_rows, expected_columns) in test_cases {
193+
let df = ctx
194+
.sql(&format!("SELECT * FROM {}(1.0)", function))
195+
.await?
196+
.collect()
197+
.await?;
48198

49-
assert_eq!(df.len(), 1);
50-
assert_eq!(df[0].num_rows(), 25);
51-
assert_eq!(df[0].num_columns(), 4);
199+
assert_eq!(df.len(), 1);
200+
assert_eq!(
201+
df[0].num_rows(),
202+
expected_rows,
203+
"{}: {}",
204+
function,
205+
expected_rows
206+
);
207+
assert_eq!(
208+
df[0].num_columns(),
209+
expected_columns,
210+
"{}: {}",
211+
function,
212+
expected_columns
213+
);
214+
}
52215
Ok(())
53216
}
54217
}

0 commit comments

Comments
 (0)