@@ -2,6 +2,7 @@ use datafusion::arrow::compute::concat_batches;
22use datafusion:: catalog:: { TableFunctionImpl , TableProvider } ;
33use datafusion:: common:: { Result , ScalarValue , plan_err} ;
44use datafusion:: datasource:: memory:: MemTable ;
5+ use datafusion:: prelude:: SessionContext ;
56use datafusion_expr:: Expr ;
67use std:: sync:: Arc ;
78use tpchgen_arrow:: RecordBatchIterator ;
@@ -10,24 +11,41 @@ use tpchgen_arrow::RecordBatchIterator;
1011/// as the data source.
1112macro_rules! define_tpch_udtf_provider {
1213 ( $TABLE_FUNCTION_NAME: ident, $TABLE_FUNCTION_SQL_NAME: ident, $GENERATOR: ty, $ARROW_GENERATOR: ty) => {
13- /// Tablle Function that returns the $TABLE_FUNCTION_NAME table.
14- ///
15- /// This function is a wrapper around the [`tpchgen`] library and builds
16- /// a table provider that can be used in a DataFusion query.
14+ #[ doc = concat!(
15+ "A table function that generates the `" ,
16+ stringify!( $TABLE_FUNCTION_SQL_NAME) ,
17+ "` table using the `tpchgen` library."
18+ ) ]
1719 ///
1820 /// The expected arguments are a float literal for the scale factor,
1921 /// an i64 literal for the part, and an i64 literal for the number of parts.
2022 /// The second and third arguments are optional and will default to 1
2123 /// for both values which tells the generator to generate all parts.
2224 ///
2325 /// # Examples
24- ///
25- /// -- This example generates TPCH nation data with scale factor 1.0.
26- /// SELECT * FROM tpchgen_nation(1.0);
27- ///
28- /// -- This example generates TPCH order data with scale factor 10 and generates
29- /// -- the second part with 5 parts.
30- /// SELECT * FROM tpchgen_order(10.0, 2, 5);
26+ /// ```
27+ /// #[tokio::main]
28+ /// async fn main() -> Result<()> {
29+ /// // create local execution context
30+ /// let ctx = SessionContext::new();
31+
32+ /// // Register all the UDTFs.
33+ /// ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {}));
34+ /// ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {}));
35+ /// ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {}));
36+ /// ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {}));
37+ /// ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {}));
38+ /// ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {}));
39+ /// ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {}));
40+ /// ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {}));
41+ /// // Generate the nation table with a scale factor of 1.
42+ /// let df = ctx
43+ /// .sql(format!("SELECT * FROM tpch_nation(1.0);").as_str())
44+ /// .await?;
45+ /// df.show().await?;
46+ /// Ok(())
47+ /// }
48+ /// ```
3149 #[ derive( Debug ) ]
3250 pub struct $TABLE_FUNCTION_NAME { }
3351
@@ -157,13 +175,73 @@ define_tpch_udtf_provider!(
157175 tpchgen_arrow:: RegionArrow
158176) ;
159177
178+ /// Registers all the TPCH UDTFs in the given session context.
179+ pub fn register_tpch_udtfs ( ctx : & SessionContext ) -> Result < ( ) > {
180+ ctx. register_udtf ( TpchNation :: name ( ) , Arc :: new ( TpchNation { } ) ) ;
181+ ctx. register_udtf ( TpchCustomer :: name ( ) , Arc :: new ( TpchCustomer { } ) ) ;
182+ ctx. register_udtf ( TpchOrders :: name ( ) , Arc :: new ( TpchOrders { } ) ) ;
183+ ctx. register_udtf ( TpchLineitem :: name ( ) , Arc :: new ( TpchLineitem { } ) ) ;
184+ ctx. register_udtf ( TpchPart :: name ( ) , Arc :: new ( TpchPart { } ) ) ;
185+ ctx. register_udtf ( TpchPartsupp :: name ( ) , Arc :: new ( TpchPartsupp { } ) ) ;
186+ ctx. register_udtf ( TpchSupplier :: name ( ) , Arc :: new ( TpchSupplier { } ) ) ;
187+ ctx. register_udtf ( TpchRegion :: name ( ) , Arc :: new ( TpchRegion { } ) ) ;
188+
189+ Ok ( ( ) )
190+ }
191+
160192#[ cfg( test) ]
161193mod tests {
162194 use super :: * ;
163195 use datafusion:: execution:: context:: SessionContext ;
164196
165197 #[ tokio:: test]
166- async fn test_tpch_functions ( ) -> Result < ( ) > {
198+ async fn test_register_all_tpch_functions ( ) -> Result < ( ) > {
199+ let ctx = SessionContext :: new ( ) ;
200+
201+ // Register all the UDTFs.
202+ register_tpch_udtfs ( & ctx) ?;
203+
204+ // Test all the UDTFs, the constants were computed using the tpchgen library
205+ // and the expected values are the number of rows and columns for each table.
206+ let test_cases = vec ! [
207+ ( TpchNation :: name( ) , 25 , 4 ) ,
208+ ( TpchCustomer :: name( ) , 150000 , 8 ) ,
209+ ( TpchOrders :: name( ) , 1500000 , 9 ) ,
210+ ( TpchLineitem :: name( ) , 6001215 , 16 ) ,
211+ ( TpchPart :: name( ) , 200000 , 9 ) ,
212+ ( TpchPartsupp :: name( ) , 800000 , 5 ) ,
213+ ( TpchSupplier :: name( ) , 10000 , 7 ) ,
214+ ( TpchRegion :: name( ) , 5 , 3 ) ,
215+ ] ;
216+
217+ for ( function, expected_rows, expected_columns) in test_cases {
218+ let df = ctx
219+ . sql ( & format ! ( "SELECT * FROM {}(1.0)" , function) )
220+ . await ?
221+ . collect ( )
222+ . await ?;
223+
224+ assert_eq ! ( df. len( ) , 1 ) ;
225+ assert_eq ! (
226+ df[ 0 ] . num_rows( ) ,
227+ expected_rows,
228+ "{}: {}" ,
229+ function,
230+ expected_rows
231+ ) ;
232+ assert_eq ! (
233+ df[ 0 ] . num_columns( ) ,
234+ expected_columns,
235+ "{}: {}" ,
236+ function,
237+ expected_columns
238+ ) ;
239+ }
240+ Ok ( ( ) )
241+ }
242+
243+ #[ tokio:: test]
244+ async fn test_register_individual_tpch_functions ( ) -> Result < ( ) > {
167245 let ctx = SessionContext :: new ( ) ;
168246
169247 // Register all the UDTFs.
0 commit comments