|
| 1 | +use datafusion::arrow::compute::concat_batches; |
1 | 2 | use datafusion::catalog::{TableFunctionImpl, TableProvider}; |
2 | 3 | use datafusion::common::{Result, ScalarValue, plan_err}; |
3 | 4 | use datafusion::datasource::memory::MemTable; |
4 | 5 | use datafusion_expr::Expr; |
5 | 6 | use std::sync::Arc; |
6 | 7 | use tpchgen_arrow::{NationArrow, RecordBatchIterator}; |
7 | 8 |
|
8 | | -/// Table function that returns the TPCH nation table. |
9 | | -#[derive(Debug)] |
10 | | -pub struct TpchNationFunction {} |
| 9 | +/// Defines a table function provider and its implementation using [`tpchgen`] |
| 10 | +/// as the data source. |
| 11 | +macro_rules! define_tpch_udtf_provider { |
| 12 | + ($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $GENERATOR:ty, $ARROW_GENERATOR:ty) => { |
| 13 | + /// Table function that returns the corresponding TPCH table. |
| 14 | + /// |
| 15 | + /// This function is a wrapper around the [`tpchgen`] library and builds |
| 16 | + /// a table provider that can be used in a DataFusion query. |
| 17 | + /// |
| 18 | + /// The expected arguments are a float literal for the scale factor, |
| 19 | + /// an i64 literal for the part, and an i64 literal for the number of parts. |
| 20 | + /// The second and third arguments are optional and will default to 1 |
| 21 | + /// for both values which tells the generator to generate all parts. |
| 22 | + /// |
| 23 | + /// # Examples |
| 24 | + /// |
| 25 | + /// -- This example generates TPCH nation data with scale factor 1.0. |
| 26 | + /// SELECT * FROM tpch_nation(1.0); |
| 27 | + /// |
| 28 | + /// -- This example generates TPCH orders data with scale factor 10 and generates |
| 29 | + /// -- the second part out of 5 parts. |
| 30 | + /// SELECT * FROM tpch_orders(10.0, 2, 5); |
| 31 | + #[derive(Debug)] |
| 32 | + pub struct $TABLE_FUNCTION_NAME {} |
11 | 33 |
|
12 | | -impl TableFunctionImpl for TpchNationFunction { |
13 | | - fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> { |
14 | | - let Some(Expr::Literal(ScalarValue::Float64(Some(value)))) = args.get(0) else { |
15 | | - return plan_err!("First argument must be a float literal."); |
16 | | - }; |
| 34 | + impl $TABLE_FUNCTION_NAME { |
| 35 | + /// Returns the name of the table function. |
| 36 | + pub fn name() -> &'static str { |
| 37 | + stringify!($TABLE_FUNCTION_SQL_NAME) |
| 38 | + } |
| 39 | + } |
17 | 40 |
|
18 | | - // Init the table generator. |
19 | | - let tablegen = tpchgen::generators::NationGenerator::new(*value, 0, 0); |
| 41 | + impl TableFunctionImpl for $TABLE_FUNCTION_NAME { |
| 42 | + /// Builds an in-memory table provider holding the generated TPCH table, |
| 43 | + /// using the [`tpchgen`] library. |
| 44 | + /// |
| 45 | + /// The first argument is a float literal that specifies the scale factor. |
| 46 | + /// The optional second argument is the part to generate and the optional |
| 47 | + /// third argument is the total number of parts; both default to 1, which |
| 48 | + /// generates the whole table, and they must be supplied together. |
| 49 | + /// |
| 50 | + /// Returns a plan error if an argument is not a literal of the expected type. |
| 51 | + fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> { |
| 52 | + let Some(Expr::Literal(ScalarValue::Float64(Some(value)))) = args.first() else { |
| 53 | + return plan_err!("First argument must be a float literal."); |
| 54 | + }; |
20 | 55 |
|
21 | | - // Init the arrow provider. |
22 | | - let mut arrow_tablegen = NationArrow::new(tablegen); |
| 56 | + // Defaults of 1 for both values tell the generator to produce all parts. |
| 57 | + let mut part = 1; |
| 58 | + let mut num_parts = 1; |
23 | 59 |
|
24 | | - let batch = arrow_tablegen.next().unwrap(); |
| 60 | + // Optional `part`/`num_parts`: positive i64 literals, supplied together. |
| 61 | + if args.len() > 1 { |
| 62 | + let Some(Expr::Literal(ScalarValue::Int64(Some(p)))) = args.get(1) else { |
| 63 | + return plan_err!("Second argument must be an i64 literal."); |
| 64 | + }; |
| 65 | + let Some(Expr::Literal(ScalarValue::Int64(Some(n)))) = args.get(2) else { |
| 66 | + return plan_err!("Third argument must be an i64 literal."); |
| 67 | + }; |
| 68 | + if *p <= 0 || *n <= 0 { |
| 69 | + return plan_err!("Second and third arguments must be greater than 0."); |
| 70 | + } |
| 71 | + // Checked narrowing to whatever integer type the generator accepts. |
| 72 | + let (Ok(p), Ok(n)) = ((*p).try_into(), (*n).try_into()) else { |
| 73 | + return plan_err!("Second and third arguments are out of range."); |
| 74 | + }; |
| 75 | + (part, num_parts) = (p, n); } |
25 | 76 |
|
26 | | - // Build the memtable plan. |
27 | | - let provider = MemTable::try_new(arrow_tablegen.schema().clone(), vec![vec![batch]])?; |
| 77 | + // Init the table generator with the validated (or default) partitioning. |
| 78 | + let tablegen = <$GENERATOR>::new(*value, part, num_parts); |
28 | 79 |
|
29 | | - Ok(Arc::new(provider)) |
30 | | - } |
| 80 | + // Init the arrow provider. |
| 81 | + let mut arrow_tablegen = <$ARROW_GENERATOR>::new(tablegen); |
| 82 | + |
| 83 | + // The arrow provider yields the table in batches, so to build the full |
| 84 | + // table we need to call `next` until it returns None. |
| 85 | + let mut batches = Vec::new(); |
| 86 | + while let Some(batch) = arrow_tablegen.next() { |
| 87 | + batches.push(batch); |
| 88 | + } |
| 89 | + // Use `concat_batches` to merge the generated batches so that the table |
| 90 | + // is served as a single batch (the test in this file relies on that). |
| 91 | + // NOTE(review): `MemTable` itself appears to accept several batches. |
| 92 | + let batch = concat_batches(arrow_tablegen.schema(), &batches)?; |
| 93 | + |
| 94 | + // Build the memtable plan. |
| 95 | + let provider = |
| 96 | + MemTable::try_new(arrow_tablegen.schema().clone(), vec![vec![batch]])?; |
| 97 | + |
| 98 | + Ok(Arc::new(provider)) |
| 99 | + } |
| 100 | + } |
| 101 | + }; |
31 | 102 | } |
32 | 103 |
|
| 104 | +define_tpch_udtf_provider!( |
| 105 | + TpchNation, |
| 106 | + tpch_nation, |
| 107 | + tpchgen::generators::NationGenerator, |
| 108 | + NationArrow |
| 109 | +); |
| 110 | + |
| 111 | +define_tpch_udtf_provider!( |
| 112 | + TpchCustomer, |
| 113 | + tpch_customer, |
| 114 | + tpchgen::generators::CustomerGenerator, |
| 115 | + tpchgen_arrow::CustomerArrow |
| 116 | +); |
| 117 | + |
| 118 | +define_tpch_udtf_provider!( |
| 119 | + TpchOrders, |
| 120 | + tpch_orders, |
| 121 | + tpchgen::generators::OrderGenerator, |
| 122 | + tpchgen_arrow::OrderArrow |
| 123 | +); |
| 124 | + |
| 125 | +define_tpch_udtf_provider!( |
| 126 | + TpchLineitem, |
| 127 | + tpch_lineitem, |
| 128 | + tpchgen::generators::LineItemGenerator, |
| 129 | + tpchgen_arrow::LineItemArrow |
| 130 | +); |
| 131 | + |
| 132 | +define_tpch_udtf_provider!( |
| 133 | + TpchPart, |
| 134 | + tpch_part, |
| 135 | + tpchgen::generators::PartGenerator, |
| 136 | + tpchgen_arrow::PartArrow |
| 137 | +); |
| 138 | + |
| 139 | +define_tpch_udtf_provider!( |
| 140 | + TpchPartsupp, |
| 141 | + tpch_partsupp, |
| 142 | + tpchgen::generators::PartSuppGenerator, |
| 143 | + tpchgen_arrow::PartSuppArrow |
| 144 | +); |
| 145 | + |
| 146 | +define_tpch_udtf_provider!( |
| 147 | + TpchSupplier, |
| 148 | + tpch_supplier, |
| 149 | + tpchgen::generators::SupplierGenerator, |
| 150 | + tpchgen_arrow::SupplierArrow |
| 151 | +); |
| 152 | + |
| 153 | +define_tpch_udtf_provider!( |
| 154 | + TpchRegion, |
| 155 | + tpch_region, |
| 156 | + tpchgen::generators::RegionGenerator, |
| 157 | + tpchgen_arrow::RegionArrow |
| 158 | +); |
| 159 | + |
33 | 160 | #[cfg(test)] |
34 | 161 | mod tests { |
35 | 162 | use super::*; |
36 | 163 | use datafusion::execution::context::SessionContext; |
37 | 164 |
|
38 | 165 | #[tokio::test] |
39 | | - async fn test_tpchgen_function() -> Result<()> { |
| 166 | + async fn test_tpch_functions() -> Result<()> { |
40 | 167 | let ctx = SessionContext::new(); |
41 | | - ctx.register_udtf("tpchgen_nation", Arc::new(TpchNationFunction {})); |
42 | 168 |
|
43 | | - let df = ctx |
44 | | - .sql("SELECT * FROM tpchgen_nation(1.0)") |
45 | | - .await? |
46 | | - .collect() |
47 | | - .await?; |
| 169 | + // Register all the UDTFs. |
| 170 | + ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {})); |
| 171 | + ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {})); |
| 172 | + ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {})); |
| 173 | + ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {})); |
| 174 | + ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {})); |
| 175 | + ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {})); |
| 176 | + ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {})); |
| 177 | + ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {})); |
| 178 | + |
| 179 | + // Test all the UDTFs, the constants were computed using the tpchgen library |
| 180 | + // and the expected values are the number of rows and columns for each table. |
| 181 | + let test_cases = vec![ |
| 182 | + (TpchNation::name(), 25, 4), |
| 183 | + (TpchCustomer::name(), 150000, 8), |
| 184 | + (TpchOrders::name(), 1500000, 9), |
| 185 | + (TpchLineitem::name(), 6001215, 16), |
| 186 | + (TpchPart::name(), 200000, 9), |
| 187 | + (TpchPartsupp::name(), 800000, 5), |
| 188 | + (TpchSupplier::name(), 10000, 7), |
| 189 | + (TpchRegion::name(), 5, 3), |
| 190 | + ]; |
| 191 | + |
| 192 | + for (function, expected_rows, expected_columns) in test_cases { |
| 193 | + let df = ctx |
| 194 | + .sql(&format!("SELECT * FROM {}(1.0)", function)) |
| 195 | + .await? |
| 196 | + .collect() |
| 197 | + .await?; |
48 | 198 |
|
49 | | - assert_eq!(df.len(), 1); |
50 | | - assert_eq!(df[0].num_rows(), 25); |
51 | | - assert_eq!(df[0].num_columns(), 4); |
| 199 | + assert_eq!(df.len(), 1); |
| 200 | + assert_eq!( |
| 201 | + df[0].num_rows(), |
| 202 | + expected_rows, |
| 203 | + "{}: {}", |
| 204 | + function, |
| 205 | + expected_rows |
| 206 | + ); |
| 207 | + assert_eq!( |
| 208 | + df[0].num_columns(), |
| 209 | + expected_columns, |
| 210 | + "{}: {}", |
| 211 | + function, |
| 212 | + expected_columns |
| 213 | + ); |
| 214 | + } |
52 | 215 | Ok(()) |
53 | 216 | } |
54 | 217 | } |
0 commit comments