1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: path:: PathBuf ;
19- use std:: sync:: Arc ;
20-
2118use super :: {
2219 get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_QUERY_END_ID ,
2320 TPCH_QUERY_START_ID , TPCH_TABLES ,
2421} ;
22+ use async_trait:: async_trait;
23+ use std:: path:: PathBuf ;
24+ use std:: sync:: Arc ;
2525
2626use datafusion:: arrow:: record_batch:: RecordBatch ;
2727use datafusion:: arrow:: util:: pretty:: { self , pretty_format_batches} ;
@@ -35,12 +35,15 @@ use datafusion::datasource::listing::{
3535 ListingOptions , ListingTable , ListingTableConfig , ListingTableUrl ,
3636} ;
3737use datafusion:: datasource:: { MemTable , TableProvider } ;
38- use datafusion:: error:: Result ;
38+ use datafusion:: error:: { DataFusionError , Result } ;
39+ use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
3940use datafusion:: physical_plan:: display:: DisplayableExecutionPlan ;
4041use datafusion:: physical_plan:: { collect, displayable} ;
4142use datafusion:: prelude:: * ;
4243
4344use crate :: util:: { print_memory_stats, BenchmarkRun , CommonOpt , QueryResult } ;
45+ use datafusion_distributed:: test_utils:: localhost:: start_localhost_context;
46+ use datafusion_distributed:: { DistributedPhysicalOptimizerRule , SessionBuilder } ;
4447use log:: info;
4548use structopt:: StructOpt ;
4649
@@ -96,26 +99,55 @@ pub struct RunOpt {
9699 /// The tables should have been created with the `--sort` option for this to have any effect.
97100 #[ structopt( short = "t" , long = "sorted" ) ]
98101 sorted : bool ,
102+
103+ /// Maximum number of partitions per task, forwarded to the distributed physical optimizer rule.
104+ /// When omitted, the rule is left with its default setting.
105+ #[ structopt( long = "ppt" ) ]
106+ partitions_per_task : Option < usize > ,
107+ }
108+
109+ #[ async_trait]
110+ impl SessionBuilder for RunOpt {
111+ fn session_state_builder (
112+ & self ,
113+ builder : SessionStateBuilder ,
114+ ) -> Result < SessionStateBuilder , DataFusionError > {
115+ let mut config = self
116+ . common
117+ . config ( ) ?
118+ . with_collect_statistics ( !self . disable_statistics ) ;
119+ config. options_mut ( ) . optimizer . prefer_hash_join = self . prefer_hash_join ;
120+ let rt_builder = self . common . runtime_env_builder ( ) ?;
121+
122+ let mut rule = DistributedPhysicalOptimizerRule :: new ( ) ;
123+ if let Some ( ppt) = self . partitions_per_task {
124+ rule = rule. with_maximum_partitions_per_task ( ppt) ;
125+ }
126+ Ok ( builder
127+ . with_config ( config)
128+ . with_physical_optimizer_rule ( Arc :: new ( rule) )
129+ . with_runtime_env ( rt_builder. build_arc ( ) ?) )
130+ }
131+
132+ async fn session_context (
133+ & self ,
134+ ctx : SessionContext ,
135+ ) -> std:: result:: Result < SessionContext , DataFusionError > {
136+ self . register_tables ( & ctx) . await ?;
137+ Ok ( ctx)
138+ }
99139}
100140
101141impl RunOpt {
102142 pub async fn run ( self ) -> Result < ( ) > {
143+ let ( ctx, _guard) = start_localhost_context ( [ 50051 ] , self . clone ( ) ) . await ;
103144 println ! ( "Running benchmarks with the following options: {self:?}" ) ;
104145 let query_range = match self . query {
105146 Some ( query_id) => query_id..=query_id,
106147 None => TPCH_QUERY_START_ID ..=TPCH_QUERY_END_ID ,
107148 } ;
108149
109150 let mut benchmark_run = BenchmarkRun :: new ( ) ;
110- let mut config = self
111- . common
112- . config ( ) ?
113- . with_collect_statistics ( !self . disable_statistics ) ;
114- config. options_mut ( ) . optimizer . prefer_hash_join = self . prefer_hash_join ;
115- let rt_builder = self . common . runtime_env_builder ( ) ?;
116- let ctx = SessionContext :: new_with_config_rt ( config, rt_builder. build_arc ( ) ?) ;
117- // register tables
118- self . register_tables ( & ctx) . await ?;
119151
120152 for query_id in query_range {
121153 benchmark_run. start_new_case ( & format ! ( "Query {query_id}" ) ) ;
@@ -368,6 +400,7 @@ mod tests {
368400 disable_statistics : false ,
369401 prefer_hash_join : true ,
370402 sorted : false ,
403+ partitions_per_task : None ,
371404 } ;
372405 opt. register_tables ( & ctx) . await ?;
373406 let queries = get_query_sql ( query) ?;
@@ -405,6 +438,7 @@ mod tests {
405438 disable_statistics : false ,
406439 prefer_hash_join : true ,
407440 sorted : false ,
441+ partitions_per_task : None ,
408442 } ;
409443 opt. register_tables ( & ctx) . await ?;
410444 let queries = get_query_sql ( query) ?;
0 commit comments