@@ -5,15 +5,36 @@ use datafusion::{
55 catalog:: { MemTable , TableProvider } ,
66} ;
77
8+ use std:: fs;
9+
10+ use tokio:: sync:: OnceCell ;
11+
12+ use arrow:: record_batch:: RecordBatch ;
13+ use parquet:: { arrow:: arrow_writer:: ArrowWriter , file:: properties:: WriterProperties } ;
14+ use tpchgen:: generators:: {
15+ CustomerGenerator , LineItemGenerator , NationGenerator , OrderGenerator , PartGenerator ,
16+ PartSuppGenerator , RegionGenerator , SupplierGenerator ,
17+ } ;
18+ use tpchgen_arrow:: {
19+ CustomerArrow , LineItemArrow , NationArrow , OrderArrow , PartArrow , PartSuppArrow , RegionArrow ,
20+ SupplierArrow ,
21+ } ;
22+
23+ const QUERIES_DIR : & str = "tests/fixtures/tpch/queries" ;
24+ pub const DATA_DIR : & str = "tests/fixtures/tpch/data" ; // mirroed in .gitignore
25+ pub const NUM_QUERIES : u8 = 22 ; // number of queries in the TPCH benchmark numbered from 1 to 22
26+
27+ const SCALE_FACTOR : f64 = 0.001 ;
28+
829pub fn tpch_table ( name : & str ) -> Arc < dyn TableProvider > {
930 let schema = Arc :: new ( get_tpch_table_schema ( name) ) ;
1031 Arc :: new ( MemTable :: try_new ( schema, vec ! [ ] ) . unwrap ( ) )
1132}
1233
1334pub fn tpch_query ( num : u8 ) -> String {
14- // read the query from the test/tpch/ queries/ directory and return it
15- let query_path = format ! ( "testing/tpch/queries/ q{}.sql" , num) ;
16- std :: fs:: read_to_string ( query_path)
35+ // read the query from the queries directory in the fixtures dir and return it
36+ let query_path = format ! ( "{}/ q{}.sql" , QUERIES_DIR , num) ;
37+ fs:: read_to_string ( query_path)
1738 . unwrap_or_else ( |_| panic ! ( "Failed to read TPCH query file: q{}.sql" , num) )
1839 . trim ( )
1940 . to_string ( )
@@ -113,3 +134,61 @@ pub fn get_tpch_table_schema(table: &str) -> Schema {
113134 _ => unimplemented ! ( ) ,
114135 }
115136}
137+
138+ // generate_table creates a parquet file in the DATA_DIR directory from an arrow RecordBatch row
139+ // source.
140+ fn generate_table < A > ( mut data_source : A , table_name : & str ) -> Result < ( ) , Box < dyn std:: error:: Error > >
141+ where
142+ A : Iterator < Item = RecordBatch > ,
143+ {
144+ let output_path = format ! ( "{}/{}.parquet" , DATA_DIR , table_name) ;
145+
146+ if let Some ( first_batch) = data_source. next ( ) {
147+ let file = fs:: File :: create ( & output_path) ?;
148+ let props = WriterProperties :: builder ( ) . build ( ) ;
149+ let mut writer = ArrowWriter :: try_new ( file, first_batch. schema ( ) , Some ( props) ) ?;
150+
151+ writer. write ( & first_batch) ?;
152+
153+ while let Some ( batch) = data_source. next ( ) {
154+ writer. write ( & batch) ?;
155+ }
156+
157+ writer. close ( ) ?;
158+ }
159+
160+ println ! ( "Generated {} table: {}" , table_name, output_path) ;
161+ Ok ( ( ) )
162+ }
163+
164+ macro_rules! must_generate_tpch_table {
165+ ( $generator: ident, $arrow: ident, $name: literal) => {
166+ generate_table(
167+ // TODO: Consider adjusting the partitions and batch sizes.
168+ $arrow:: new( $generator:: new( SCALE_FACTOR , 1 , 1 ) ) . with_batch_size( 1000 ) ,
169+ $name,
170+ )
171+ . expect( concat!( "Failed to generate " , $name, " table" ) ) ;
172+ } ;
173+ }
174+
175+ // INIT_TPCH_TABLES is ensures that TPC-H tables are generated only once.
176+ static INIT_TPCH_TABLES : OnceCell < ( ) > = OnceCell :: const_new ( ) ;
177+
178+ // generate_tpch_data generates all TPC-H tables in the DATA_DIR directory.
179+ pub async fn generate_tpch_data ( ) {
180+ INIT_TPCH_TABLES
181+ . get_or_init ( || async {
182+ fs:: create_dir_all ( DATA_DIR ) . expect ( "Failed to create data directory" ) ;
183+
184+ must_generate_tpch_table ! ( RegionGenerator , RegionArrow , "region" ) ;
185+ must_generate_tpch_table ! ( NationGenerator , NationArrow , "nation" ) ;
186+ must_generate_tpch_table ! ( CustomerGenerator , CustomerArrow , "customer" ) ;
187+ must_generate_tpch_table ! ( SupplierGenerator , SupplierArrow , "supplier" ) ;
188+ must_generate_tpch_table ! ( PartGenerator , PartArrow , "part" ) ;
189+ must_generate_tpch_table ! ( PartSuppGenerator , PartSuppArrow , "partsupp" ) ;
190+ must_generate_tpch_table ! ( OrderGenerator , OrderArrow , "orders" ) ;
191+ must_generate_tpch_table ! ( LineItemGenerator , LineItemArrow , "lineitem" ) ;
192+ } )
193+ . await ;
194+ }
0 commit comments