@@ -28,7 +28,7 @@ use datafusion::arrow::record_batch::RecordBatch;
2828use datafusion:: arrow:: util:: pretty:: { self , pretty_format_batches} ;
2929use datafusion:: common:: instant:: Instant ;
3030use datafusion:: common:: utils:: get_available_parallelism;
31- use datafusion:: common:: { DEFAULT_CSV_EXTENSION , DEFAULT_PARQUET_EXTENSION } ;
31+ use datafusion:: common:: { exec_err , DEFAULT_CSV_EXTENSION , DEFAULT_PARQUET_EXTENSION } ;
3232use datafusion:: datasource:: file_format:: csv:: CsvFormat ;
3333use datafusion:: datasource:: file_format:: parquet:: ParquetFormat ;
3434use datafusion:: datasource:: file_format:: FileFormat ;
@@ -50,6 +50,7 @@ use datafusion_distributed::{
5050 DistributedSessionBuilderContext ,
5151} ;
5252use log:: info;
53+ use std:: fs;
5354use std:: path:: PathBuf ;
5455use std:: sync:: Arc ;
5556use structopt:: StructOpt ;
@@ -77,8 +78,8 @@ pub struct RunOpt {
7778 common : CommonOpt ,
7879
7980 /// Path to data files
80- #[ structopt( parse( from_os_str) , required = true , short = "p" , long = "path" ) ]
81- path : PathBuf ,
81+ #[ structopt( parse( from_os_str) , short = "p" , long = "path" ) ]
82+ path : Option < PathBuf > ,
8283
8384 /// File format: `csv` or `parquet`
8485 #[ structopt( short = "f" , long = "format" , default_value = "parquet" ) ]
@@ -211,7 +212,7 @@ impl RunOpt {
211212 } ;
212213
213214 self . output_path
214- . get_or_insert_with ( || self . path . join ( "results.json" ) ) ;
215+ . get_or_insert ( self . get_path ( ) ? . join ( "results.json" ) ) ;
215216 let mut benchmark_run = BenchmarkRun :: new ( ) ;
216217
217218 for query_id in query_range {
@@ -335,8 +336,25 @@ impl RunOpt {
335336 Ok ( result)
336337 }
337338
339+ fn get_path ( & self ) -> Result < PathBuf > {
340+ if let Some ( path) = & self . path {
341+ return Ok ( path. clone ( ) ) ;
342+ }
343+ let crate_path = PathBuf :: from ( env ! ( "CARGO_MANIFEST_DIR" ) ) ;
344+ let data_path = crate_path. join ( "data" ) ;
345+ let entries = fs:: read_dir ( & data_path) ?. collect :: < Result < Vec < _ > , _ > > ( ) ?;
346+ if entries. is_empty ( ) {
347+ exec_err ! ( "No TPCH dataset present in '{data_path:?}'. Generate one with ./benchmarks/gen-tpch.sh" )
348+ } else if entries. len ( ) == 1 {
349+ Ok ( entries[ 0 ] . path ( ) )
350+ } else {
351+ exec_err ! ( "Multiple TPCH datasets present in '{data_path:?}'. One must be selected with --path" )
352+ }
353+ }
354+
338355 async fn get_table ( & self , ctx : & SessionContext , table : & str ) -> Result < Arc < dyn TableProvider > > {
339- let path = self . path . to_str ( ) . unwrap ( ) ;
356+ let path = self . get_path ( ) ?;
357+ let path = path. to_str ( ) . unwrap ( ) ;
340358 let table_format = self . file_format . as_str ( ) ;
341359 let target_partitions = self . partitions ( ) ;
342360
0 commit comments