Skip to content

Commit 5326433

Browse files
committed
fix: remove extra complexity from the tablegen logic defer to COPY for writing to disk
1 parent b94a714 commit 5326433

File tree

2 files changed

+76
-88
lines changed

2 files changed

+76
-88
lines changed

examples/tpchgen.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//! Example of using the datafusion-tpch extension to generate TPCH datasets
2+
//! on the the fly in datafusion.
3+
4+
use datafusion::prelude::{SessionConfig, SessionContext};
5+
use datafusion_tpch::register_tpch_udtf;
6+
7+
#[tokio::main]
8+
async fn main() -> datafusion::error::Result<()> {
9+
let ctx = SessionContext::new_with_config(SessionConfig::new().with_information_schema(true));
10+
register_tpch_udtf(&ctx);
11+
12+
let sql_df = ctx.sql(&format!("SELECT * FROM tpch(1.0);")).await?;
13+
sql_df.show().await?;
14+
15+
let sql_df = ctx.sql(&format!("SHOW TABLES;")).await?;
16+
sql_df.show().await?;
17+
18+
let sql_df = ctx.sql(&format!("SELECT * FROM nation LIMIT 5;")).await?;
19+
sql_df.show().await?;
20+
21+
let sql_df = ctx.sql(&format!("SELECT * FROM partsupp LIMIT 5;")).await?;
22+
sql_df.show().await?;
23+
24+
let sql_df = ctx.sql(&format!("SELECT * FROM region LIMIT 5;")).await?;
25+
sql_df.show().await?;
26+
27+
let sql_df = ctx.sql(&format!("SELECT * FROM customer LIMIT 5;")).await?;
28+
sql_df.show().await?;
29+
30+
let sql_df = ctx.sql(&format!("SELECT * FROM orders LIMIT 5;")).await?;
31+
sql_df.show().await?;
32+
33+
let sql_df = ctx.sql(&format!("SELECT * FROM lineitem LIMIT 5;")).await?;
34+
sql_df.show().await?;
35+
36+
let sql_df = ctx.sql(&format!("SELECT * FROM part LIMIT 5;")).await?;
37+
sql_df.show().await?;
38+
Ok(())
39+
}

src/lib.rs

Lines changed: 37 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -226,86 +226,41 @@ impl TpchTables {
226226
provider: P,
227227
table_name: &str,
228228
scale_factor: f64,
229-
write_to_disk: bool,
230-
_path: &str,
231229
) -> Result<()> {
232-
// Short path when the table is generated in memory only.
233-
if !write_to_disk {
234-
let table = provider
235-
.call(vec![Expr::Literal(ScalarValue::Float64(Some(scale_factor)))].as_slice())?;
236-
self.ctx
237-
.register_table(TableReference::bare(table_name), table)?;
238-
return Ok(());
239-
}
230+
let table = provider
231+
.call(vec![Expr::Literal(ScalarValue::Float64(Some(scale_factor)))].as_slice())?;
232+
self.ctx
233+
.register_table(TableReference::bare(table_name), table)?;
240234

241235
Ok(())
242236
}
243237

244238
/// Build and register all TPCH tables in the session context.
245-
fn build_and_register_all_tables(
246-
&self,
247-
scale_factor: f64,
248-
write_to_disk: bool,
249-
path: &str,
250-
) -> Result<()> {
239+
fn build_and_register_all_tables(&self, scale_factor: f64) -> Result<()> {
251240
for &suffix in Self::TPCH_TABLE_NAMES {
252241
match suffix {
253-
"nation" => self.build_and_register_tpch_table(
254-
TpchNation {},
255-
suffix,
256-
scale_factor,
257-
write_to_disk,
258-
path,
259-
)?,
260-
"customer" => self.build_and_register_tpch_table(
261-
TpchCustomer {},
262-
suffix,
263-
scale_factor,
264-
write_to_disk,
265-
path,
266-
)?,
267-
"orders" => self.build_and_register_tpch_table(
268-
TpchOrders {},
269-
suffix,
270-
scale_factor,
271-
write_to_disk,
272-
path,
273-
)?,
274-
"lineitem" => self.build_and_register_tpch_table(
275-
TpchLineitem {},
276-
suffix,
277-
scale_factor,
278-
write_to_disk,
279-
path,
280-
)?,
281-
"part" => self.build_and_register_tpch_table(
282-
TpchPart {},
283-
suffix,
284-
scale_factor,
285-
write_to_disk,
286-
path,
287-
)?,
288-
"partsupp" => self.build_and_register_tpch_table(
289-
TpchPartsupp {},
290-
suffix,
291-
scale_factor,
292-
write_to_disk,
293-
path,
294-
)?,
295-
"supplier" => self.build_and_register_tpch_table(
296-
TpchSupplier {},
297-
suffix,
298-
scale_factor,
299-
write_to_disk,
300-
path,
301-
)?,
302-
"region" => self.build_and_register_tpch_table(
303-
TpchRegion {},
304-
suffix,
305-
scale_factor,
306-
write_to_disk,
307-
path,
308-
)?,
242+
"nation" => {
243+
self.build_and_register_tpch_table(TpchNation {}, suffix, scale_factor)?
244+
}
245+
"customer" => {
246+
self.build_and_register_tpch_table(TpchCustomer {}, suffix, scale_factor)?
247+
}
248+
"orders" => {
249+
self.build_and_register_tpch_table(TpchOrders {}, suffix, scale_factor)?
250+
}
251+
"lineitem" => {
252+
self.build_and_register_tpch_table(TpchLineitem {}, suffix, scale_factor)?
253+
}
254+
"part" => self.build_and_register_tpch_table(TpchPart {}, suffix, scale_factor)?,
255+
"partsupp" => {
256+
self.build_and_register_tpch_table(TpchPartsupp {}, suffix, scale_factor)?
257+
}
258+
"supplier" => {
259+
self.build_and_register_tpch_table(TpchSupplier {}, suffix, scale_factor)?
260+
}
261+
"region" => {
262+
self.build_and_register_tpch_table(TpchRegion {}, suffix, scale_factor)?
263+
}
309264
_ => unreachable!("Unknown TPCH table suffix: {}", suffix), // Should not happen
310265
}
311266
}
@@ -325,27 +280,21 @@ impl TableFunctionImpl for TpchTables {
325280
/// The `call` method is the entry point for the UDTF and is called when the UDTF is
326281
/// invoked in a SQL query.
327282
///
328-
/// It takes a list of arguments, the scale factor, whether to generate the data on
329-
/// disk in parquet format and the path to the output files. If no path is provided,
330-
/// the data is generated in memory and we fallback to the `MemTable` provider.
283+
/// The UDF requires one argument, the scale factor, and allows a second optional
284+
/// argument which is a path on disk. If a path is specified, the data is flushed
285+
/// to disk from the generated memory table.
331286
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
332287
let scale_factor = match args.first() {
333288
Some(Expr::Literal(ScalarValue::Float64(Some(value)))) => *value,
334-
_ => return plan_err!("First argument must be a float literal."),
335-
};
336-
337-
let write_to_disk = match args.get(1) {
338-
Some(Expr::Literal(ScalarValue::Boolean(Some(value)))) => *value,
339-
_ => false,
340-
};
341-
342-
let path = match args.get(2) {
343-
Some(Expr::Literal(ScalarValue::Utf8(Some(value)))) => value.clone(),
344-
_ => "".to_string(),
289+
_ => {
290+
return plan_err!(
291+
"First argument must be a float literal that specifies the scale factor."
292+
);
293+
}
345294
};
346295

347296
// Register the TPCH tables in the session context.
348-
self.build_and_register_all_tables(scale_factor, write_to_disk, &path)?;
297+
self.build_and_register_all_tables(scale_factor)?;
349298

350299
// Create a table with the schema |table_name| and the data is just the
351300
// individual table names.
@@ -488,7 +437,7 @@ mod tests {
488437

489438
// Test the TPCH provider.
490439
let df = ctx
491-
.sql("SELECT * FROM tpch(1.0, false, '')")
440+
.sql("SELECT * FROM tpch(1.0, '')")
492441
.await?
493442
.collect()
494443
.await?;

0 commit comments

Comments
 (0)