Skip to content

Commit 443ada1

Browse files
committed
fix: improve readme and doc comments
1 parent 06181ed commit 443ada1

File tree

2 files changed

+128
-12
lines changed

2 files changed

+128
-12
lines changed

README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,45 @@
11
# datafusion-tpch
22

3+
[![Apache licensed][license-badge]][license-url]
4+
[![Build Status][actions-badge]][actions-url]
5+
6+
[license-badge]: https://img.shields.io/badge/license-Apache%20v2-blue.svg
7+
[license-url]: https://github.com/clflushopt/datafusion-tpch/blob/main/LICENSE
8+
[actions-badge]: https://github.com/clflushopt/datafusion-tpch/actions/workflows/rust.yml/badge.svg
9+
[actions-url]: https://github.com/clflushopt/datafusion-tpch/actions?query=branch%3Amain
10+
311
Note: This is not an official Apache Software Foundation release.
412

513
This crate provides functions to generate the TPCH benchmark dataset for DataFusion
614
using the [tpchgen](https://github.com/clflushopt/tpchgen-rs) crates.
715

16+
## Usage
17+
18+
The `datafusion-tpch` crate offers two ways to register the individual TPCH
19+
table functions.
20+
21+
```rust
22+
#[tokio::main]
23+
async fn main() -> Result<()> {
24+
// create local execution context
25+
let ctx = SessionContext::new();
26+
27+
// Register all the UDTFs.
28+
ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {}));
29+
ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {}));
30+
ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {}));
31+
ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {}));
32+
ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {}));
33+
ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {}));
34+
ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {}));
35+
ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {}));
36+
37+
// Generate the nation table with a scale factor of 1.
38+
let df = ctx
39+
.sql(format!("SELECT * FROM tpch_nation(1.0);").as_str())
40+
.await?;
41+
df.show().await?;
42+
Ok(())
43+
}
44+
```
45+

src/lib.rs

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use datafusion::arrow::compute::concat_batches;
22
use datafusion::catalog::{TableFunctionImpl, TableProvider};
33
use datafusion::common::{Result, ScalarValue, plan_err};
44
use datafusion::datasource::memory::MemTable;
5+
use datafusion::prelude::SessionContext;
56
use datafusion_expr::Expr;
67
use std::sync::Arc;
78
use tpchgen_arrow::RecordBatchIterator;
@@ -10,24 +11,41 @@ use tpchgen_arrow::RecordBatchIterator;
1011
/// as the data source.
1112
macro_rules! define_tpch_udtf_provider {
1213
($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => {
13-
/// Tablle Function that returns the $TABLE_FUNCTION_NAME table.
14-
///
15-
/// This function is a wrapper around the [`tpchgen`] library and builds
16-
/// a table provider that can be used in a DataFusion query.
14+
#[doc = concat!(
15+
"A table function that generates the `",
16+
stringify!($TABLE_FUNCTION_SQL_NAME),
17+
"` table using the `tpchgen` library."
18+
)]
1719
///
1820
/// The expected arguments are a float literal for the scale factor,
1921
/// an i64 literal for the part, and an i64 literal for the number of parts.
2022
/// The second and third arguments are optional and will default to 1
2123
/// for both values, which tells the generator to generate all parts.
2224
///
2325
/// # Examples
24-
///
25-
/// -- This example generates TPCH nation data with scale factor 1.0.
26-
/// SELECT * FROM tpchgen_nation(1.0);
27-
///
28-
/// -- This example generates TPCH order data with scale factor 10 and generates
29-
/// -- the second part with 5 parts.
30-
/// SELECT * FROM tpchgen_order(10.0, 2, 5);
26+
/// ```
27+
/// #[tokio::main]
28+
/// async fn main() -> Result<()> {
29+
/// // create local execution context
30+
/// let ctx = SessionContext::new();
31+
32+
/// // Register all the UDTFs.
33+
/// ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {}));
34+
/// ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {}));
35+
/// ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {}));
36+
/// ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {}));
37+
/// ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {}));
38+
/// ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {}));
39+
/// ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {}));
40+
/// ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {}));
41+
/// // Generate the nation table with a scale factor of 1.
42+
/// let df = ctx
43+
/// .sql(format!("SELECT * FROM tpch_nation(1.0);").as_str())
44+
/// .await?;
45+
/// df.show().await?;
46+
/// Ok(())
47+
/// }
48+
/// ```
3149
#[derive(Debug)]
3250
pub struct $TABLE_FUNCTION_NAME {}
3351

@@ -157,13 +175,73 @@ define_tpch_udtf_provider!(
157175
tpchgen_arrow::RegionArrow
158176
);
159177

178+
/// Registers all the TPCH UDTFs in the given session context.
179+
pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> {
180+
ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {}));
181+
ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {}));
182+
ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {}));
183+
ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {}));
184+
ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {}));
185+
ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {}));
186+
ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {}));
187+
ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {}));
188+
189+
Ok(())
190+
}
191+
160192
#[cfg(test)]
161193
mod tests {
162194
use super::*;
163195
use datafusion::execution::context::SessionContext;
164196

165197
#[tokio::test]
166-
async fn test_tpch_functions() -> Result<()> {
198+
async fn test_register_all_tpch_functions() -> Result<()> {
199+
let ctx = SessionContext::new();
200+
201+
// Register all the UDTFs.
202+
register_tpch_udtfs(&ctx)?;
203+
204+
// Test all the UDTFs. The constants were computed using the tpchgen library
205+
// and the expected values are the number of rows and columns for each table.
206+
let test_cases = vec![
207+
(TpchNation::name(), 25, 4),
208+
(TpchCustomer::name(), 150000, 8),
209+
(TpchOrders::name(), 1500000, 9),
210+
(TpchLineitem::name(), 6001215, 16),
211+
(TpchPart::name(), 200000, 9),
212+
(TpchPartsupp::name(), 800000, 5),
213+
(TpchSupplier::name(), 10000, 7),
214+
(TpchRegion::name(), 5, 3),
215+
];
216+
217+
for (function, expected_rows, expected_columns) in test_cases {
218+
let df = ctx
219+
.sql(&format!("SELECT * FROM {}(1.0)", function))
220+
.await?
221+
.collect()
222+
.await?;
223+
224+
assert_eq!(df.len(), 1);
225+
assert_eq!(
226+
df[0].num_rows(),
227+
expected_rows,
228+
"{}: {}",
229+
function,
230+
expected_rows
231+
);
232+
assert_eq!(
233+
df[0].num_columns(),
234+
expected_columns,
235+
"{}: {}",
236+
function,
237+
expected_columns
238+
);
239+
}
240+
Ok(())
241+
}
242+
243+
#[tokio::test]
244+
async fn test_register_individual_tpch_functions() -> Result<()> {
167245
let ctx = SessionContext::new();
168246

169247
// Register all the UDTFs.

0 commit comments

Comments
 (0)