Skip to content

Commit 6e11566

Browse files
authored
Merge pull request #1 from clflushopt/cl/feat/tpch-table-func-global
feat: implement a core tpch table func to generate all data
2 parents 886dae3 + 2f64af5 commit 6e11566

File tree

3 files changed

+268
-9
lines changed

3 files changed

+268
-9
lines changed

examples/parquet.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//! Example of using the datafusion-tpch extension to generate TPCH tables
2+
//! and writing them to disk via `COPY`.
3+
4+
use datafusion::prelude::{SessionConfig, SessionContext};
5+
use datafusion_tpch::{register_tpch_udtf, register_tpch_udtfs};
6+
7+
#[tokio::main]
8+
async fn main() -> datafusion::error::Result<()> {
9+
let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true));
10+
register_tpch_udtf(&ctx);
11+
12+
let sql_df = ctx.sql(&format!("SELECT * FROM tpch(1.0);")).await?;
13+
sql_df.show().await?;
14+
15+
let sql_df = ctx.sql(&format!("SHOW TABLES;")).await?;
16+
sql_df.show().await?;
17+
18+
let sql_df = ctx
19+
.sql(&format!(
20+
"COPY nation TO './tpch_nation.parquet' STORED AS PARQUET"
21+
))
22+
.await?;
23+
sql_df.show().await?;
24+
25+
register_tpch_udtfs(&ctx)?;
26+
27+
let sql_df = ctx
28+
.sql(&format!(
29+
"COPY (SELECT * FROM tpch_lineitem(1.0)) TO './tpch_lineitem_sf_10.parquet' STORED AS PARQUET"
30+
))
31+
.await?;
32+
sql_df.show().await?;
33+
34+
Ok(())
35+
}

examples/tpchgen.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//! Example of using the datafusion-tpch extension to generate TPCH datasets
2+
//! on the the fly in datafusion.
3+
4+
use datafusion::prelude::{SessionConfig, SessionContext};
5+
use datafusion_tpch::register_tpch_udtf;
6+
7+
#[tokio::main]
8+
async fn main() -> datafusion::error::Result<()> {
9+
let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true));
10+
register_tpch_udtf(&ctx);
11+
12+
let sql_df = ctx.sql(&format!("SELECT * FROM tpch(1.0);")).await?;
13+
sql_df.show().await?;
14+
15+
let sql_df = ctx.sql(&format!("SHOW TABLES;")).await?;
16+
sql_df.show().await?;
17+
18+
let sql_df = ctx.sql(&format!("SELECT * FROM nation LIMIT 5;")).await?;
19+
sql_df.show().await?;
20+
21+
let sql_df = ctx.sql(&format!("SELECT * FROM partsupp LIMIT 5;")).await?;
22+
sql_df.show().await?;
23+
24+
let sql_df = ctx.sql(&format!("SELECT * FROM region LIMIT 5;")).await?;
25+
sql_df.show().await?;
26+
27+
let sql_df = ctx.sql(&format!("SELECT * FROM customer LIMIT 5;")).await?;
28+
sql_df.show().await?;
29+
30+
let sql_df = ctx.sql(&format!("SELECT * FROM orders LIMIT 5;")).await?;
31+
sql_df.show().await?;
32+
33+
let sql_df = ctx.sql(&format!("SELECT * FROM lineitem LIMIT 5;")).await?;
34+
sql_df.show().await?;
35+
36+
let sql_df = ctx.sql(&format!("SELECT * FROM part LIMIT 5;")).await?;
37+
sql_df.show().await?;
38+
Ok(())
39+
}

src/lib.rs

Lines changed: 194 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
use datafusion::arrow::compute::concat_batches;
2+
use datafusion::arrow::datatypes::Schema;
23
use datafusion::catalog::{TableFunctionImpl, TableProvider};
34
use datafusion::common::{Result, ScalarValue, plan_err};
45
use datafusion::datasource::memory::MemTable;
56
use datafusion::prelude::SessionContext;
7+
use datafusion::sql::TableReference;
68
use datafusion_expr::Expr;
9+
use std::fmt::Debug;
710
use std::sync::Arc;
811
use tpchgen_arrow::RecordBatchIterator;
912

1013
/// Defines a table function provider and its implementation using [`tpchgen`]
1114
/// as the data source.
1215
macro_rules! define_tpch_udtf_provider {
1316
($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => {
14-
#[doc = concat!(
15-
"A table function that generates the `",
16-
stringify!($TABLE_FUNCTION_SQL_NAME),
17-
"` table using the `tpchgen` library."
18-
)]
17+
#[doc = concat!("A table function that generates the `",stringify!($TABLE_FUNCTION_SQL_NAME),"` table using the `tpchgen` library.")]
1918
///
2019
/// The expected arguments are a float literal for the scale factor,
2120
/// an i64 literal for the part, and an i64 literal for the number of parts.
@@ -59,6 +58,19 @@ macro_rules! define_tpch_udtf_provider {
5958
pub fn name() -> &'static str {
6059
stringify!($TABLE_FUNCTION_SQL_NAME)
6160
}
61+
62+
/// Returns the name of the table generated by the table function
63+
/// when used in a SQL query.
64+
pub fn table_name() -> &'static str {
65+
stringify!($TABLE_FUNCTION_SQL_NAME)
66+
.strip_prefix("tpch_")
67+
.unwrap_or_else(|| {
68+
panic!(
69+
"Table function name {} does not start with tpch_",
70+
stringify!($TABLE_FUNCTION_SQL_NAME)
71+
)
72+
})
73+
}
6274
}
6375

6476
impl TableFunctionImpl for $TABLE_FUNCTION_NAME {
@@ -194,6 +206,122 @@ pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> {
194206
Ok(())
195207
}
196208

209+
/// Table function provider for TPCH tables.
210+
struct TpchTables {
211+
ctx: SessionContext,
212+
}
213+
214+
impl TpchTables {
215+
const TPCH_TABLE_NAMES: &[&str] = &[
216+
"nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier", "region",
217+
];
218+
/// Creates a new TPCH table provider.
219+
pub fn new(ctx: SessionContext) -> Self {
220+
Self { ctx }
221+
}
222+
223+
/// Build and register a TPCH table by it's table function provider.
224+
fn build_and_register_tpch_table<P: TableFunctionImpl>(
225+
&self,
226+
provider: P,
227+
table_name: &str,
228+
scale_factor: f64,
229+
) -> Result<()> {
230+
let table = provider
231+
.call(vec![Expr::Literal(ScalarValue::Float64(Some(scale_factor)))].as_slice())?;
232+
self.ctx
233+
.register_table(TableReference::bare(table_name), table)?;
234+
235+
Ok(())
236+
}
237+
238+
/// Build and register all TPCH tables in the session context.
239+
fn build_and_register_all_tables(&self, scale_factor: f64) -> Result<()> {
240+
for &suffix in Self::TPCH_TABLE_NAMES {
241+
match suffix {
242+
"nation" => {
243+
self.build_and_register_tpch_table(TpchNation {}, suffix, scale_factor)?
244+
}
245+
"customer" => {
246+
self.build_and_register_tpch_table(TpchCustomer {}, suffix, scale_factor)?
247+
}
248+
"orders" => {
249+
self.build_and_register_tpch_table(TpchOrders {}, suffix, scale_factor)?
250+
}
251+
"lineitem" => {
252+
self.build_and_register_tpch_table(TpchLineitem {}, suffix, scale_factor)?
253+
}
254+
"part" => self.build_and_register_tpch_table(TpchPart {}, suffix, scale_factor)?,
255+
"partsupp" => {
256+
self.build_and_register_tpch_table(TpchPartsupp {}, suffix, scale_factor)?
257+
}
258+
"supplier" => {
259+
self.build_and_register_tpch_table(TpchSupplier {}, suffix, scale_factor)?
260+
}
261+
"region" => {
262+
self.build_and_register_tpch_table(TpchRegion {}, suffix, scale_factor)?
263+
}
264+
_ => unreachable!("Unknown TPCH table suffix: {}", suffix), // Should not happen
265+
}
266+
}
267+
Ok(())
268+
}
269+
}
270+
271+
// Implement the `TableProvider` trait for the `TpchTableProvider`, we need
272+
// to do it manually because the `SessionContext` does not implement it.
273+
impl Debug for TpchTables {
274+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275+
write!(f, "TpchTableProvider")
276+
}
277+
}
278+
279+
impl TableFunctionImpl for TpchTables {
280+
/// The `call` method is the entry point for the UDTF and is called when the UDTF is
281+
/// invoked in a SQL query.
282+
///
283+
/// The UDF requires one argument, the scale factor, and allows a second optional
284+
/// argument which is a path on disk. If a path is specified, the data is flushed
285+
/// to disk from the generated memory table.
286+
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
287+
let scale_factor = match args.first() {
288+
Some(Expr::Literal(ScalarValue::Float64(Some(value)))) => *value,
289+
_ => {
290+
return plan_err!(
291+
"First argument must be a float literal that specifies the scale factor."
292+
);
293+
}
294+
};
295+
296+
// Register the TPCH tables in the session context.
297+
self.build_and_register_all_tables(scale_factor)?;
298+
299+
// Create a table with the schema |table_name| and the data is just the
300+
// individual table names.
301+
let schema = Schema::new(vec![datafusion::arrow::datatypes::Field::new(
302+
"table_name",
303+
datafusion::arrow::datatypes::DataType::Utf8,
304+
false,
305+
)]);
306+
let batch = datafusion::arrow::record_batch::RecordBatch::try_new(
307+
Arc::new(schema.clone()),
308+
vec![Arc::new(datafusion::arrow::array::StringArray::from(vec![
309+
"nation", "customer", "orders", "lineitem", "part", "partsupp", "supplier",
310+
"region",
311+
]))],
312+
)?;
313+
let mem_table = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
314+
315+
Ok(Arc::new(mem_table))
316+
}
317+
}
318+
319+
/// Register the `tpch` table function.
320+
pub fn register_tpch_udtf(ctx: &SessionContext) {
321+
let tpch_udtf = TpchTables::new(ctx.clone());
322+
ctx.register_udtf("tpch", Arc::new(tpch_udtf));
323+
}
324+
197325
#[cfg(test)]
198326
mod tests {
199327
use super::*;
@@ -203,12 +331,15 @@ mod tests {
203331
async fn test_register_all_tpch_functions() -> Result<()> {
204332
let ctx = SessionContext::new();
205333

334+
let tpch_tbl_fn = TpchTables::new(ctx.clone());
335+
ctx.register_udtf("tcph", Arc::new(tpch_tbl_fn));
336+
206337
// Register all the UDTFs.
207338
register_tpch_udtfs(&ctx)?;
208339

209340
// Test all the UDTFs, the constants were computed using the tpchgen library
210341
// and the expected values are the number of rows and columns for each table.
211-
let test_cases = vec![
342+
let expected_tables = vec![
212343
(TpchNation::name(), 25, 4),
213344
(TpchCustomer::name(), 150000, 8),
214345
(TpchOrders::name(), 1500000, 9),
@@ -219,7 +350,7 @@ mod tests {
219350
(TpchRegion::name(), 5, 3),
220351
];
221352

222-
for (function, expected_rows, expected_columns) in test_cases {
353+
for (function, expected_rows, expected_columns) in expected_tables {
223354
let df = ctx
224355
.sql(&format!("SELECT * FROM {}(1.0)", function))
225356
.await?
@@ -261,7 +392,7 @@ mod tests {
261392

262393
// Test all the UDTFs, the constants were computed using the tpchgen library
263394
// and the expected values are the number of rows and columns for each table.
264-
let test_cases = vec![
395+
let expected_tables = vec![
265396
(TpchNation::name(), 25, 4),
266397
(TpchCustomer::name(), 150000, 8),
267398
(TpchOrders::name(), 1500000, 9),
@@ -272,7 +403,7 @@ mod tests {
272403
(TpchRegion::name(), 5, 3),
273404
];
274405

275-
for (function, expected_rows, expected_columns) in test_cases {
406+
for (function, expected_rows, expected_columns) in expected_tables {
276407
let df = ctx
277408
.sql(&format!("SELECT * FROM {}(1.0)", function))
278409
.await?
@@ -297,4 +428,58 @@ mod tests {
297428
}
298429
Ok(())
299430
}
431+
432+
#[tokio::test]
433+
async fn test_register_tpch_provider() -> Result<()> {
434+
let ctx = SessionContext::new();
435+
436+
register_tpch_udtf(&ctx);
437+
438+
// Test the TPCH provider.
439+
let df = ctx
440+
.sql("SELECT * FROM tpch(1.0, '')")
441+
.await?
442+
.collect()
443+
.await?;
444+
445+
assert_eq!(df.len(), 1);
446+
assert_eq!(df[0].num_rows(), 8);
447+
assert_eq!(df[0].num_columns(), 1);
448+
449+
let expected_tables = vec![
450+
(TpchNation::table_name(), 25, 4),
451+
(TpchCustomer::table_name(), 150000, 8),
452+
(TpchOrders::table_name(), 1500000, 9),
453+
(TpchLineitem::table_name(), 6001215, 16),
454+
(TpchPart::table_name(), 200000, 9),
455+
(TpchPartsupp::table_name(), 800000, 5),
456+
(TpchSupplier::table_name(), 10000, 7),
457+
(TpchRegion::table_name(), 5, 3),
458+
];
459+
460+
for (function, expected_rows, expected_columns) in expected_tables {
461+
let df = ctx
462+
.sql(&format!("SELECT * FROM {}", function))
463+
.await?
464+
.collect()
465+
.await?;
466+
467+
assert_eq!(df.len(), 1);
468+
assert_eq!(
469+
df[0].num_rows(),
470+
expected_rows,
471+
"{}: {}",
472+
function,
473+
expected_rows
474+
);
475+
assert_eq!(
476+
df[0].num_columns(),
477+
expected_columns,
478+
"{}: {}",
479+
function,
480+
expected_columns
481+
);
482+
}
483+
Ok(())
484+
}
300485
}

0 commit comments

Comments
 (0)