@@ -63,8 +63,8 @@ use crate::flight_service::BallistaFlightService;
6363use crate :: metrics:: LoggingMetricsCollector ;
6464use crate :: shutdown:: Shutdown ;
6565use crate :: shutdown:: ShutdownNotifier ;
66- use crate :: terminate;
6766use crate :: { execution_loop, executor_server} ;
67+ use crate :: { terminate, ArrowFlightServerProvider } ;
6868
6969pub struct ExecutorProcessConfig {
7070 pub bind_host : String ,
@@ -101,6 +101,8 @@ pub struct ExecutorProcessConfig {
101101 pub override_logical_codec : Option < Arc < dyn LogicalExtensionCodec > > ,
102102 /// [PhysicalExtensionCodec] override option
103103 pub override_physical_codec : Option < Arc < dyn PhysicalExtensionCodec > > ,
104+ /// [ArrowFlightServerProvider] implementation override option
105+ pub override_arrow_flight_service : Option < Arc < ArrowFlightServerProvider > > ,
104106}
105107
106108impl ExecutorProcessConfig {
@@ -143,6 +145,7 @@ impl Default for ExecutorProcessConfig {
143145 override_config_producer : None ,
144146 override_logical_codec : None ,
145147 override_physical_codec : None ,
148+ override_arrow_flight_service : None ,
146149 }
147150 }
148151}
@@ -151,7 +154,7 @@ pub async fn start_executor_process(
151154 opt : Arc < ExecutorProcessConfig > ,
152155) -> ballista_core:: error:: Result < ( ) > {
153156 let addr = format ! ( "{}:{}" , opt. bind_host, opt. port) ;
154- let addr = addr. parse ( ) . map_err ( |e : std:: net:: AddrParseError | {
157+ let address = addr. parse ( ) . map_err ( |e : std:: net:: AddrParseError | {
155158 BallistaError :: Configuration ( e. to_string ( ) )
156159 } ) ?;
157160
@@ -174,9 +177,12 @@ pub async fn start_executor_process(
174177 opt. concurrent_tasks
175178 } ;
176179
177- info ! ( "Running with config:" ) ;
178- info ! ( "work_dir: {}" , work_dir) ;
179- info ! ( "concurrent_tasks: {}" , concurrent_tasks) ;
180+ info ! (
181+ "Executor starting ... (Datafusion Ballista {})" ,
182+ BALLISTA_VERSION
183+ ) ;
184+ info ! ( "Executor working directory: {}" , work_dir) ;
185+ info ! ( "Executor number of concurrent tasks: {}" , concurrent_tasks) ;
180186
181187 // assign this executor an unique ID
182188 let executor_id = Uuid :: new_v4 ( ) . to_string ( ) ;
@@ -261,16 +267,16 @@ pub async fn start_executor_process(
261267 "Could not connect to scheduler" . to_string ( ) ,
262268 )
263269 } ) {
264- Ok ( conn ) => {
270+ Ok ( connection ) => {
265271 info ! ( "Connected to scheduler at {}" , scheduler_url) ;
266- x = Some ( conn ) ;
272+ x = Some ( connection ) ;
267273 }
268274 Err ( e) => {
269275 warn ! (
270276 "Failed to connect to scheduler at {} ({}); retrying ..." ,
271277 scheduler_url, e
272278 ) ;
273- std :: thread :: sleep ( time:: Duration :: from_millis ( 500 ) ) ;
279+ tokio :: time :: sleep ( time:: Duration :: from_millis ( 500 ) ) . await ;
274280 }
275281 }
276282 }
@@ -290,13 +296,15 @@ pub async fn start_executor_process(
290296 let job_data_ttl_seconds = opt. job_data_ttl_seconds ;
291297
292298 // Graceful shutdown notification
293- let shutdown_noti = ShutdownNotifier :: new ( ) ;
299+ let shutdown_notification = ShutdownNotifier :: new ( ) ;
294300
295301 if opt. job_data_clean_up_interval_seconds > 0 {
296302 let mut interval_time =
297303 time:: interval ( Duration :: from_secs ( opt. job_data_clean_up_interval_seconds ) ) ;
298- let mut shuffle_cleaner_shutdown = shutdown_noti. subscribe_for_shutdown ( ) ;
299- let shuffle_cleaner_complete = shutdown_noti. shutdown_complete_tx . clone ( ) ;
304+
305+ let mut shuffle_cleaner_shutdown = shutdown_notification. subscribe_for_shutdown ( ) ;
306+ let shuffle_cleaner_complete = shutdown_notification. shutdown_complete_tx . clone ( ) ;
307+
300308 tokio:: spawn ( async move {
301309 // As long as the shutdown notification has not been received
302310 while !shuffle_cleaner_shutdown. is_shutdown ( ) {
@@ -338,7 +346,7 @@ pub async fn start_executor_process(
338346 executor. clone ( ) ,
339347 default_codec,
340348 stop_send,
341- & shutdown_noti ,
349+ & shutdown_notification ,
342350 )
343351 . await ?,
344352 ) ;
@@ -351,10 +359,19 @@ pub async fn start_executor_process(
351359 ) ) ) ;
352360 }
353361 } ;
354- service_handlers. push ( tokio:: spawn ( flight_server_run (
355- addr,
356- shutdown_noti. subscribe_for_shutdown ( ) ,
357- ) ) ) ;
362+ let shutdown = shutdown_notification. subscribe_for_shutdown ( ) ;
363+ let override_flight = opt. override_arrow_flight_service . clone ( ) ;
364+
365+ service_handlers. push ( match override_flight {
366+ None => {
367+ info ! ( "Starting built-in arrow flight service" ) ;
368+ flight_server_task ( address, shutdown) . await
369+ }
370+ Some ( flight_provider) => {
371+ info ! ( "Starting custom, user provided, arrow flight service" ) ;
372+ ( flight_provider) ( address, shutdown)
373+ }
374+ } ) ;
358375
359376 let tasks_drained = TasksDrainedFuture ( executor) ;
360377
@@ -436,7 +453,7 @@ pub async fn start_executor_process(
436453 shutdown_complete_tx,
437454 notify_shutdown,
438455 ..
439- } = shutdown_noti ;
456+ } = shutdown_notification ;
440457
441458 // When `notify_shutdown` is dropped, all components which have `subscribe`d will
442459 // receive the shutdown signal and can exit
@@ -451,25 +468,21 @@ pub async fn start_executor_process(
451468}
452469
453470// Arrow flight service
454- async fn flight_server_run (
455- addr : SocketAddr ,
471+ async fn flight_server_task (
472+ address : SocketAddr ,
456473 mut grpc_shutdown : Shutdown ,
457- ) -> Result < ( ) , BallistaError > {
458- let service = BallistaFlightService :: new ( ) ;
459- let server = FlightServiceServer :: new ( service) ;
460- info ! (
461- "Ballista v{} Rust Executor Flight Server listening on {:?}" ,
462- BALLISTA_VERSION , addr
463- ) ;
464-
465- let shutdown_signal = grpc_shutdown. recv ( ) ;
466- let server_future = create_grpc_server ( )
467- . add_service ( server)
468- . serve_with_shutdown ( addr, shutdown_signal) ;
469-
470- server_future. await . map_err ( |e| {
471- error ! ( "Tonic error, Could not start Executor Flight Server." ) ;
472- BallistaError :: TonicError ( e)
474+ ) -> JoinHandle < Result < ( ) , BallistaError > > {
475+ tokio:: spawn ( async move {
476+ info ! ( "Built-in arrow flight server listening on: {:?}" , address) ;
477+
478+ let server_future = create_grpc_server ( )
479+ . add_service ( FlightServiceServer :: new ( BallistaFlightService :: new ( ) ) )
480+ . serve_with_shutdown ( address, grpc_shutdown. recv ( ) ) ;
481+
482+ server_future. await . map_err ( |e| {
483+ error ! ( "Could not start built-in arrow flight server." ) ;
484+ BallistaError :: TonicError ( e)
485+ } )
473486 } )
474487}
475488
@@ -642,4 +655,36 @@ mod tests {
642655 let count2 = fs:: read_dir ( work_dir. clone ( ) ) . unwrap ( ) . count ( ) ;
643656 assert_eq ! ( count2, 0 ) ;
644657 }
658+
659+ #[ tokio:: test]
660+ async fn test_arrow_flight_provider_ergonomics ( ) {
661+ let config = crate :: executor_process:: ExecutorProcessConfig {
662+ override_arrow_flight_service : Some ( std:: sync:: Arc :: new (
663+ move |address, mut grpc_shutdown| {
664+ tokio:: spawn ( async move {
665+ log:: info!(
666+ "custom arrow flight server listening on: {:?}" ,
667+ address
668+ ) ;
669+
670+ let server_future = ballista_core:: utils:: create_grpc_server ( )
671+ . add_service (
672+ arrow_flight:: flight_service_server:: FlightServiceServer :: new (
673+ crate :: flight_service:: BallistaFlightService :: new ( ) ,
674+ ) ,
675+ )
676+ . serve_with_shutdown ( address, grpc_shutdown. recv ( ) ) ;
677+
678+ server_future. await . map_err ( |e| {
679+ log:: error!( "Could not start built-in arrow flight server." ) ;
680+ ballista_core:: error:: BallistaError :: TonicError ( e)
681+ } )
682+ } )
683+ } ,
684+ ) ) ,
685+ ..Default :: default ( )
686+ } ;
687+
688+ assert ! ( config. override_arrow_flight_service. is_some( ) ) ;
689+ }
645690}
0 commit comments