11use datafusion:: arrow:: compute:: concat_batches;
2+ use datafusion:: arrow:: datatypes:: Schema ;
23use datafusion:: catalog:: { TableFunctionImpl , TableProvider } ;
34use datafusion:: common:: { Result , ScalarValue , plan_err} ;
45use datafusion:: datasource:: memory:: MemTable ;
56use datafusion:: prelude:: SessionContext ;
7+ use datafusion:: sql:: TableReference ;
68use datafusion_expr:: Expr ;
9+ use std:: fmt:: Debug ;
710use std:: sync:: Arc ;
811use tpchgen_arrow:: RecordBatchIterator ;
912
1013/// Defines a table function provider and its implementation using [`tpchgen`]
1114/// as the data source.
1215macro_rules! define_tpch_udtf_provider {
1316 ( $TABLE_FUNCTION_NAME: ident, $TABLE_FUNCTION_SQL_NAME: ident, $GENERATOR: ty, $ARROW_GENERATOR: ty) => {
14- #[ doc = concat!(
15- "A table function that generates the `" ,
16- stringify!( $TABLE_FUNCTION_SQL_NAME) ,
17- "` table using the `tpchgen` library."
18- ) ]
17+ #[ doc = concat!( "A table function that generates the `" , stringify!( $TABLE_FUNCTION_SQL_NAME) , "` table using the `tpchgen` library." ) ]
1918 ///
2019 /// The expected arguments are a float literal for the scale factor,
2120 /// an i64 literal for the part, and an i64 literal for the number of parts.
@@ -59,6 +58,19 @@ macro_rules! define_tpch_udtf_provider {
5958 pub fn name( ) -> & ' static str {
6059 stringify!( $TABLE_FUNCTION_SQL_NAME)
6160 }
61+
62+ /// Returns the name of the table generated by the table function
63+ /// when used in a SQL query.
64+ pub fn table_name( ) -> & ' static str {
65+ stringify!( $TABLE_FUNCTION_SQL_NAME)
66+ . strip_prefix( "tpch_" )
67+ . unwrap_or_else( || {
68+ panic!(
69+ "Table function name {} does not start with tpch_" ,
70+ stringify!( $TABLE_FUNCTION_SQL_NAME)
71+ )
72+ } )
73+ }
6274 }
6375
6476 impl TableFunctionImpl for $TABLE_FUNCTION_NAME {
@@ -194,6 +206,122 @@ pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> {
194206 Ok ( ( ) )
195207}
196208
209+ /// Table function provider for TPCH tables.
210+ struct TpchTables {
211+ ctx : SessionContext ,
212+ }
213+
214+ impl TpchTables {
215+ const TPCH_TABLE_NAMES : & [ & str ] = & [
216+ "nation" , "customer" , "orders" , "lineitem" , "part" , "partsupp" , "supplier" , "region" ,
217+ ] ;
218+ /// Creates a new TPCH table provider.
219+ pub fn new ( ctx : SessionContext ) -> Self {
220+ Self { ctx }
221+ }
222+
223+ /// Build and register a TPCH table by it's table function provider.
224+ fn build_and_register_tpch_table < P : TableFunctionImpl > (
225+ & self ,
226+ provider : P ,
227+ table_name : & str ,
228+ scale_factor : f64 ,
229+ ) -> Result < ( ) > {
230+ let table = provider
231+ . call ( vec ! [ Expr :: Literal ( ScalarValue :: Float64 ( Some ( scale_factor) ) ) ] . as_slice ( ) ) ?;
232+ self . ctx
233+ . register_table ( TableReference :: bare ( table_name) , table) ?;
234+
235+ Ok ( ( ) )
236+ }
237+
238+ /// Build and register all TPCH tables in the session context.
239+ fn build_and_register_all_tables ( & self , scale_factor : f64 ) -> Result < ( ) > {
240+ for & suffix in Self :: TPCH_TABLE_NAMES {
241+ match suffix {
242+ "nation" => {
243+ self . build_and_register_tpch_table ( TpchNation { } , suffix, scale_factor) ?
244+ }
245+ "customer" => {
246+ self . build_and_register_tpch_table ( TpchCustomer { } , suffix, scale_factor) ?
247+ }
248+ "orders" => {
249+ self . build_and_register_tpch_table ( TpchOrders { } , suffix, scale_factor) ?
250+ }
251+ "lineitem" => {
252+ self . build_and_register_tpch_table ( TpchLineitem { } , suffix, scale_factor) ?
253+ }
254+ "part" => self . build_and_register_tpch_table ( TpchPart { } , suffix, scale_factor) ?,
255+ "partsupp" => {
256+ self . build_and_register_tpch_table ( TpchPartsupp { } , suffix, scale_factor) ?
257+ }
258+ "supplier" => {
259+ self . build_and_register_tpch_table ( TpchSupplier { } , suffix, scale_factor) ?
260+ }
261+ "region" => {
262+ self . build_and_register_tpch_table ( TpchRegion { } , suffix, scale_factor) ?
263+ }
264+ _ => unreachable ! ( "Unknown TPCH table suffix: {}" , suffix) , // Should not happen
265+ }
266+ }
267+ Ok ( ( ) )
268+ }
269+ }
270+
271+ // Implement the `TableProvider` trait for the `TpchTableProvider`, we need
272+ // to do it manually because the `SessionContext` does not implement it.
273+ impl Debug for TpchTables {
274+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
275+ write ! ( f, "TpchTableProvider" )
276+ }
277+ }
278+
279+ impl TableFunctionImpl for TpchTables {
280+ /// The `call` method is the entry point for the UDTF and is called when the UDTF is
281+ /// invoked in a SQL query.
282+ ///
283+ /// The UDF requires one argument, the scale factor, and allows a second optional
284+ /// argument which is a path on disk. If a path is specified, the data is flushed
285+ /// to disk from the generated memory table.
286+ fn call ( & self , args : & [ Expr ] ) -> Result < Arc < dyn TableProvider > > {
287+ let scale_factor = match args. first ( ) {
288+ Some ( Expr :: Literal ( ScalarValue :: Float64 ( Some ( value) ) ) ) => * value,
289+ _ => {
290+ return plan_err ! (
291+ "First argument must be a float literal that specifies the scale factor."
292+ ) ;
293+ }
294+ } ;
295+
296+ // Register the TPCH tables in the session context.
297+ self . build_and_register_all_tables ( scale_factor) ?;
298+
299+ // Create a table with the schema |table_name| and the data is just the
300+ // individual table names.
301+ let schema = Schema :: new ( vec ! [ datafusion:: arrow:: datatypes:: Field :: new(
302+ "table_name" ,
303+ datafusion:: arrow:: datatypes:: DataType :: Utf8 ,
304+ false ,
305+ ) ] ) ;
306+ let batch = datafusion:: arrow:: record_batch:: RecordBatch :: try_new (
307+ Arc :: new ( schema. clone ( ) ) ,
308+ vec ! [ Arc :: new( datafusion:: arrow:: array:: StringArray :: from( vec![
309+ "nation" , "customer" , "orders" , "lineitem" , "part" , "partsupp" , "supplier" ,
310+ "region" ,
311+ ] ) ) ] ,
312+ ) ?;
313+ let mem_table = MemTable :: try_new ( Arc :: new ( schema) , vec ! [ vec![ batch] ] ) ?;
314+
315+ Ok ( Arc :: new ( mem_table) )
316+ }
317+ }
318+
319+ /// Register the `tpch` table function.
320+ pub fn register_tpch_udtf ( ctx : & SessionContext ) {
321+ let tpch_udtf = TpchTables :: new ( ctx. clone ( ) ) ;
322+ ctx. register_udtf ( "tpch" , Arc :: new ( tpch_udtf) ) ;
323+ }
324+
197325#[ cfg( test) ]
198326mod tests {
199327 use super :: * ;
@@ -203,12 +331,15 @@ mod tests {
203331 async fn test_register_all_tpch_functions ( ) -> Result < ( ) > {
204332 let ctx = SessionContext :: new ( ) ;
205333
334+ let tpch_tbl_fn = TpchTables :: new ( ctx. clone ( ) ) ;
335+ ctx. register_udtf ( "tcph" , Arc :: new ( tpch_tbl_fn) ) ;
336+
206337 // Register all the UDTFs.
207338 register_tpch_udtfs ( & ctx) ?;
208339
209340 // Test all the UDTFs, the constants were computed using the tpchgen library
210341 // and the expected values are the number of rows and columns for each table.
211- let test_cases = vec ! [
342+ let expected_tables = vec ! [
212343 ( TpchNation :: name( ) , 25 , 4 ) ,
213344 ( TpchCustomer :: name( ) , 150000 , 8 ) ,
214345 ( TpchOrders :: name( ) , 1500000 , 9 ) ,
@@ -219,7 +350,7 @@ mod tests {
219350 ( TpchRegion :: name( ) , 5 , 3 ) ,
220351 ] ;
221352
222- for ( function, expected_rows, expected_columns) in test_cases {
353+ for ( function, expected_rows, expected_columns) in expected_tables {
223354 let df = ctx
224355 . sql ( & format ! ( "SELECT * FROM {}(1.0)" , function) )
225356 . await ?
@@ -261,7 +392,7 @@ mod tests {
261392
262393 // Test all the UDTFs, the constants were computed using the tpchgen library
263394 // and the expected values are the number of rows and columns for each table.
264- let test_cases = vec ! [
395+ let expected_tables = vec ! [
265396 ( TpchNation :: name( ) , 25 , 4 ) ,
266397 ( TpchCustomer :: name( ) , 150000 , 8 ) ,
267398 ( TpchOrders :: name( ) , 1500000 , 9 ) ,
@@ -272,7 +403,7 @@ mod tests {
272403 ( TpchRegion :: name( ) , 5 , 3 ) ,
273404 ] ;
274405
275- for ( function, expected_rows, expected_columns) in test_cases {
406+ for ( function, expected_rows, expected_columns) in expected_tables {
276407 let df = ctx
277408 . sql ( & format ! ( "SELECT * FROM {}(1.0)" , function) )
278409 . await ?
@@ -297,4 +428,58 @@ mod tests {
297428 }
298429 Ok ( ( ) )
299430 }
431+
432+ #[ tokio:: test]
433+ async fn test_register_tpch_provider ( ) -> Result < ( ) > {
434+ let ctx = SessionContext :: new ( ) ;
435+
436+ register_tpch_udtf ( & ctx) ;
437+
438+ // Test the TPCH provider.
439+ let df = ctx
440+ . sql ( "SELECT * FROM tpch(1.0, '')" )
441+ . await ?
442+ . collect ( )
443+ . await ?;
444+
445+ assert_eq ! ( df. len( ) , 1 ) ;
446+ assert_eq ! ( df[ 0 ] . num_rows( ) , 8 ) ;
447+ assert_eq ! ( df[ 0 ] . num_columns( ) , 1 ) ;
448+
449+ let expected_tables = vec ! [
450+ ( TpchNation :: table_name( ) , 25 , 4 ) ,
451+ ( TpchCustomer :: table_name( ) , 150000 , 8 ) ,
452+ ( TpchOrders :: table_name( ) , 1500000 , 9 ) ,
453+ ( TpchLineitem :: table_name( ) , 6001215 , 16 ) ,
454+ ( TpchPart :: table_name( ) , 200000 , 9 ) ,
455+ ( TpchPartsupp :: table_name( ) , 800000 , 5 ) ,
456+ ( TpchSupplier :: table_name( ) , 10000 , 7 ) ,
457+ ( TpchRegion :: table_name( ) , 5 , 3 ) ,
458+ ] ;
459+
460+ for ( function, expected_rows, expected_columns) in expected_tables {
461+ let df = ctx
462+ . sql ( & format ! ( "SELECT * FROM {}" , function) )
463+ . await ?
464+ . collect ( )
465+ . await ?;
466+
467+ assert_eq ! ( df. len( ) , 1 ) ;
468+ assert_eq ! (
469+ df[ 0 ] . num_rows( ) ,
470+ expected_rows,
471+ "{}: {}" ,
472+ function,
473+ expected_rows
474+ ) ;
475+ assert_eq ! (
476+ df[ 0 ] . num_columns( ) ,
477+ expected_columns,
478+ "{}: {}" ,
479+ function,
480+ expected_columns
481+ ) ;
482+ }
483+ Ok ( ( ) )
484+ }
300485}
0 commit comments