@@ -42,9 +42,8 @@ use datafusion::physical_plan::display::DisplayableExecutionPlan;
4242use datafusion:: physical_plan:: { collect, displayable} ;
4343use datafusion:: prelude:: * ;
4444use datafusion_distributed:: test_utils:: localhost:: {
45- get_free_ports , spawn_flight_service, LocalHostChannelResolver ,
45+ spawn_flight_service, LocalHostChannelResolver ,
4646} ;
47- use datafusion_distributed:: MappedDistributedSessionBuilderExt ;
4847use datafusion_distributed:: {
4948 DistributedExt , DistributedPhysicalOptimizerRule , DistributedSessionBuilder ,
5049 DistributedSessionBuilderContext ,
@@ -101,19 +100,19 @@ pub struct RunOpt {
101100 #[ structopt( short = "t" , long = "sorted" ) ]
102101 sorted : bool ,
103102
104- /// Run in distributed mode.
105- #[ structopt( short = "D" , long = "distributed" ) ]
106- distributed : bool ,
107-
108103 /// Number of partitions per task.
109104 #[ structopt( long = "ppt" ) ]
110105 partitions_per_task : Option < usize > ,
111106
112- /// Number of physical threads per worker (default 1)
113- #[ structopt( long, default_value = "1" ) ]
114- workers : usize ,
107+ /// Spawns a worker in the specified port.
108+ #[ structopt( long) ]
109+ spawn : Option < u16 > ,
110+
111+ /// The ports of all the workers involved in the query.
112+ #[ structopt( long, use_delimiter = true ) ]
113+ workers : Vec < u16 > ,
115114
116- /// Number of physical threads per worker
115+ /// Number of physical threads per worker.
117116 #[ structopt( long) ]
118117 threads : Option < usize > ,
119118}
@@ -126,7 +125,7 @@ impl DistributedSessionBuilder for RunOpt {
126125 ) -> Result < SessionState , DataFusionError > {
127126 let mut builder = SessionStateBuilder :: new ( ) . with_default_features ( ) ;
128127
129- let config = self
128+ let mut config = self
130129 . common
131130 . config ( ) ?
132131 . with_collect_statistics ( !self . disable_statistics )
@@ -139,11 +138,13 @@ impl DistributedSessionBuilder for RunOpt {
139138 if self . mem_table {
140139 builder = builder. with_physical_optimizer_rule ( Arc :: new ( InMemoryDataSourceRule ) ) ;
141140 }
142- if self . distributed {
141+ if ! self . workers . is_empty ( ) {
143142 let mut rule = DistributedPhysicalOptimizerRule :: new ( ) ;
144143 if let Some ( partitions_per_task) = self . partitions_per_task {
145144 rule = rule. with_maximum_partitions_per_task ( partitions_per_task)
146145 }
146+ let ports = self . workers . clone ( ) ;
147+ config = config. with_distributed_channel_resolver ( LocalHostChannelResolver :: new ( ports) ) ;
147148 builder = builder. with_physical_optimizer_rule ( Arc :: new ( rule) ) ;
148149 }
149150
@@ -156,61 +157,25 @@ impl DistributedSessionBuilder for RunOpt {
156157
157158impl RunOpt {
158159 pub fn run ( self ) -> Result < ( ) > {
159- let ports = get_free_ports ( self . workers ) ;
160-
161- let _worker_handles = self . clone ( ) . spawn_workers ( ports. clone ( ) ) ;
162-
163160 let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
164161 . worker_threads ( self . threads . unwrap_or ( get_available_parallelism ( ) ) )
165162 . enable_all ( )
166163 . build ( ) ?;
167164
168- rt. block_on ( async move { self . run_local ( ports) . await } )
169- }
170-
171- pub fn spawn_workers ( self , ports : Vec < u16 > ) -> Vec < std:: thread:: JoinHandle < ( ) > > {
172- let threads_per_worker = self . threads ;
173- let ports_copy = ports. clone ( ) ;
174- let session_builder = self . map ( move |builder : SessionStateBuilder | {
175- let channel_resolver = LocalHostChannelResolver :: new ( ports. clone ( ) ) ;
176- Ok ( builder
177- . with_distributed_channel_resolver ( channel_resolver)
178- . build ( ) )
179- } ) ;
180- let mut handles = vec ! [ ] ;
181- for port in ports_copy {
182- let session_builder = session_builder. clone ( ) ;
183- let handle = std:: thread:: spawn ( move || {
184- let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
185- . worker_threads ( threads_per_worker. unwrap_or ( get_available_parallelism ( ) ) )
186- . enable_all ( )
187- . build ( )
188- . unwrap ( ) ;
189- rt. block_on ( async move {
190- let listener = TcpListener :: bind ( format ! ( "127.0.0.1:{port}" ) )
191- . await
192- . unwrap ( ) ;
193- spawn_flight_service ( session_builder, listener)
194- . await
195- . unwrap ( ) ;
196- } )
197- } ) ;
198-
199- handles. push ( handle) ;
165+ if let Some ( port) = self . spawn {
166+ rt. block_on ( async move {
167+ let listener = TcpListener :: bind ( format ! ( "127.0.0.1:{port}" ) ) . await ?;
168+ println ! ( "Listening on {}..." , listener. local_addr( ) . unwrap( ) ) ;
169+ spawn_flight_service ( self , listener) . await
170+ } ) ?;
171+ } else {
172+ rt. block_on ( self . run_local ( ) ) ?;
200173 }
201- handles
174+ Ok ( ( ) )
202175 }
203176
204- async fn run_local ( mut self , ports : Vec < u16 > ) -> Result < ( ) > {
205- let session_builder = self . clone ( ) . map ( move |builder : SessionStateBuilder | {
206- let channel_resolver = LocalHostChannelResolver :: new ( ports. clone ( ) ) ;
207- Ok ( builder
208- . with_distributed_channel_resolver ( channel_resolver)
209- . build ( ) )
210- } ) ;
211- let state = session_builder
212- . build_session_state ( DistributedSessionBuilderContext :: default ( ) )
213- . await ?;
177+ async fn run_local ( mut self ) -> Result < ( ) > {
178+ let state = self . build_session_state ( Default :: default ( ) ) . await ?;
214179 let ctx = SessionContext :: new_with_state ( state) ;
215180 self . register_tables ( & ctx) . await ?;
216181
@@ -229,12 +194,11 @@ impl RunOpt {
229194 for query_id in query_range. clone ( ) {
230195 // put the WarmingUpMarker in the context, otherwise, queries will fail as the
231196 // InMemoryCacheExec node will think they should already be warmed up.
232- let sql = & get_query_sql ( query_id) ?;
233197 let ctx = ctx
234198 . clone ( )
235199 . with_distributed_option_extension ( WarmingUpMarker :: warming_up ( ) ) ?;
236- for query in sql . iter ( ) {
237- self . execute_query ( & ctx, query) . await ?;
200+ for query in get_query_sql ( query_id ) ? {
201+ self . execute_query ( & ctx, & query) . await ?;
238202 }
239203 println ! ( "Query {query_id} data loaded in memory" ) ;
240204 }
0 commit comments