 use std::{
-    collections::HashMap,
-    env,
-    sync::{Arc, LazyLock},
+    collections::HashMap, env, sync::{Arc, LazyLock}
 };

 use anyhow::{anyhow, Context};
@@ -30,19 +28,10 @@ use itertools::Itertools;
 use prost::Message;

 use crate::{
-    analyze::{DistributedAnalyzeExec, DistributedAnalyzeRootExec},
-    isolator::PartitionIsolatorExec,
-    logging::{debug, error, info, trace},
-    max_rows::MaxRowsExec,
-    physical::DDStageOptimizerRule,
-    result::{DDError, Result},
-    stage::DDStageExec,
-    stage_reader::{DDStageReaderExec, QueryId},
-    util::{display_plan_with_partition_counts, get_client, physical_plan_to_bytes, wait_for},
-    vocab::{
+    analyze::{DistributedAnalyzeExec, DistributedAnalyzeRootExec}, isolator::PartitionIsolatorExec, logging::{debug, info, trace, error}, max_rows::MaxRowsExec, physical::DDStageOptimizerRule, result::{DDError, Result}, stage::DDStageExec, stage_reader::{DDStageReaderExec, QueryId}, transport::WorkerTransport, util::{display_plan_with_partition_counts, physical_plan_to_bytes, wait_for}, vocab::{
         Addrs, CtxAnnotatedOutputs, CtxHost, CtxPartitionGroup, CtxStageAddrs, CtxStageId, DDTask,
         Host, Hosts, PartitionAddrs, StageAddrs,
-    },
+    }
 };

 #[derive(Debug)]
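The new `transport::WorkerTransport` import is the heart of this change, but the trait itself is defined outside this diff. Below is a minimal sketch of a shape consistent with how it is used in the hunks that follow (`Arc<dyn WorkerTransport>`, `tx.do_action(action).await` yielding a drainable stream). The names and the stream item type are assumptions, not the actual `transport.rs` definition:

```rust
use std::fmt::Debug;

use arrow_flight::{error::FlightError, Action};
use bytes::Bytes;
use futures::stream::BoxStream;

/// Hypothetical trait shape only; the real definition lives in `transport.rs`.
#[async_trait::async_trait]
pub trait WorkerTransport: Debug + Send + Sync {
    /// Send a Flight `do_action` request to the worker, returning the
    /// (possibly empty) response stream on success.
    async fn do_action(
        &self,
        action: Action,
    ) -> Result<BoxStream<'static, Result<Bytes, FlightError>>, FlightError>;
}
```

An `Arc<dyn WorkerTransport>` can then be cloned cheaply into the per-task pairs that `assign_to_workers` returns below.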
@@ -439,109 +428,84 @@ pub fn add_distributed_analyze(
 pub async fn distribute_stages(
     query_id: &str,
     stages: Vec<DDStage>,
-    worker_addrs: Vec<Host>,
+    workers: Vec<(Host, Arc<dyn WorkerTransport>)>,
     codec: &dyn PhysicalExtensionCodec,
 ) -> Result<(Addrs, Vec<DDTask>)> {
-    // map of worker name to address
-    // FIXME: use types over tuples of strings, as we can accidently swap them and
-    // not know
-
-    // a map of worker name to host
-    let mut workers: HashMap<String, Host> = worker_addrs
-        .iter()
-        .map(|host| (host.name.clone(), host.clone()))
+    // materialise a name-keyed map so we can remove "bad" workers on each retry
+    let mut valid_workers: HashMap<_, _> = workers
+        .into_iter()
+        .map(|(h, tx)| (h.name.clone(), (h, tx)))
         .collect();

     for attempt in 0..3 {
-        if workers.is_empty() {
+        if valid_workers.is_empty() {
             return Err(anyhow!("No workers available to distribute stages").into());
         }

-        // all stages to workers
-        let (task_datas, final_addrs) =
-            assign_to_workers(query_id, &stages, workers.values().collect(), codec)?;
+        let current: Vec<_> = valid_workers.values().cloned().collect();
+        let (tasks, final_addrs, tx_host_pairs) =
+            assign_to_workers(query_id, &stages, current, codec)?;
+
+        match try_distribute_tasks(&tasks, &tx_host_pairs).await {
+            Ok(_) => return Ok((final_addrs, tasks)),

-        // we retry this a few times to ensure that the workers are ready
-        // and can accept the stages
-        match try_distribute_tasks(&task_datas).await {
-            Ok(_) => return Ok((final_addrs, task_datas)),
-            Err(DDError::WorkerCommunicationError(bad_worker)) => {
+            // remove the poisoned worker and retry on the remaining ones
+            Err(DDError::WorkerCommunicationError(bad_host)) => {
                 error!(
-                    "distribute stages for query {query_id} attempt {attempt} failed removing \
-                     worker {bad_worker}. Retrying..."
+                    "distribute_stages: attempt {attempt}: worker {} failed; \
+                     will retry without it",
+                    bad_host.name
                 );
-                // if we cannot communicate with a worker, we remove it from the list of workers
-                workers.remove(&bad_worker.name);
+                valid_workers.remove(&bad_host.name);
             }
+
+            // any other error is terminal
             Err(e) => return Err(e),
         }
-        if attempt == 2 {
-            return Err(
-                anyhow!("Failed to distribute query {query_id} stages after 3 attempts").into(),
-            );
-        }
     }
-    unreachable!()
+
+    // if we get here, every attempt failed with a worker-communication error
+    Err(anyhow!("Failed to distribute query {query_id} stages after 3 attempts").into())
 }

 /// try to distribute the stages to the workers, if we cannot communicate with a
 /// worker return it as the element in the Err
-async fn try_distribute_tasks(task_datas: &[DDTask]) -> Result<()> {
-    // we can use the stage data to distribute the stages to workers
-    for task_data in task_datas {
+async fn try_distribute_tasks(
+    tasks: &[DDTask],
+    tx_host_pairs: &[(Arc<dyn WorkerTransport>, Host)],
+) -> Result<()> {
+    for (task, (tx, host)) in tasks.iter().zip(tx_host_pairs) {
         trace!(
-            "Distributing Task: stage_id {}, pg: {:?} to worker: {:?}",
-            task_data.stage_id,
-            task_data.partition_group,
-            task_data.assigned_host
+            "Sending stage {} / pg {:?} to {}",
+            task.stage_id,
+            task.partition_group,
+            host
         );

-        // populate its child stages
-        let mut stage_data = task_data.clone();
-        stage_data.stage_addrs = Some(get_stage_addrs_from_tasks(
-            &stage_data.child_stage_ids,
-            task_datas,
+        // embed the StageAddrs of all children before shipping
+        let mut stage = task.clone();
+        stage.stage_addrs = Some(get_stage_addrs_from_tasks(
+            &stage.child_stage_ids,
+            tasks,
         )?);

-        let host = stage_data
-            .assigned_host
-            .clone()
-            .context("Assigned host is missing for task data")?;
-
-        let mut client = match get_client(&host) {
-            Ok(client) => client,
-            Err(e) => {
-                error!("Couldn't not communicate with worker {e:#?}");
-                return Err(DDError::WorkerCommunicationError(
-                    host.clone(), // here
-                ));
-            }
-        };
-
-        let mut buf = vec![];
-        stage_data
-            .encode(&mut buf)
-            .context("Failed to encode stage data to buf")?;
+        let mut buf = Vec::new();
+        stage.encode(&mut buf).context("Failed to encode stage data to buf")?;

         let action = Action {
-            r#type: "add_plan".to_string(),
+            r#type: "add_plan".into(),
             body: buf.into(),
         };

-        let mut response = client
+        // gRPC call; on failure the transport poisons itself and removes
+        // its address from the registry
+        let mut stream = tx
             .do_action(action)
             .await
-            .context("Failed to send action to worker")?;
+            .map_err(|_| DDError::WorkerCommunicationError(host.clone()))?;

-        // consume this empty response to ensure the action was successful
-        while let Some(_res) = response
-            .try_next()
-            .await
-            .context("error consuming do_action response")?
-        {
-            // we don't care about the response, just that it was successful
-        }
-        trace!("do action success for stage_id: {}", stage_data.stage_id);
+        // drain the (empty) response to ensure the worker actually accepted it
+        while stream
+            .try_next()
+            .await
+            .context("error consuming do_action response")?
+            .is_some()
+        {}
+
+        trace!("stage {} delivered to {}", stage.stage_id, host);
     }
     Ok(())
 }
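The comment above says the transport "poisons itself" on a failed call, but that mechanism is not shown in this diff. One plausible sketch, with all names invented for illustration, is a flag the registry checks before handing a cached transport out again:

```rust
use std::sync::atomic::{AtomicBool, Ordering};

/// Illustrative only: marks a transport unusable after its first failed call
/// so a registry sweep can evict it. Not the actual `transport.rs` code.
#[derive(Debug, Default)]
struct PoisonFlag(AtomicBool);

impl PoisonFlag {
    /// Called from the transport's error path.
    fn poison(&self) {
        self.0.store(true, Ordering::Release);
    }

    /// Checked by the registry before reusing a cached transport.
    fn is_poisoned(&self) -> bool {
        self.0.load(Ordering::Acquire)
    }
}
```

Whatever the real mechanism, `try_distribute_tasks` only needs to surface `WorkerCommunicationError(host)`; the retry loop in `distribute_stages` handles eviction on its side via `valid_workers.remove`.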
@@ -552,40 +516,35 @@ async fn try_distribute_tasks(task_datas: &[DDTask]) -> Result<()> {
 fn assign_to_workers(
     query_id: &str,
     stages: &[DDStage],
-    worker_addrs: Vec<&Host>,
+    workers: Vec<(Host, Arc<dyn WorkerTransport>)>,
     codec: &dyn PhysicalExtensionCodec,
-) -> Result<(Vec<DDTask>, Addrs)> {
-    let mut task_datas = vec![];
-    let mut worker_idx = 0;
+) -> Result<(Vec<DDTask>, Addrs, Vec<(Arc<dyn WorkerTransport>, Host)>)> {
+    let mut task_datas = Vec::new();
+    let mut tx_host_pairs = Vec::new();

-    trace!(
-        "assigning stages: {:?}",
-        stages
-            .iter()
-            .map(|s| format!("stage_id: {}, pgs:{:?}", s.stage_id, s.partition_groups))
-            .join(",\n")
-    );
+    // round-robin assignment over the available workers
+    let mut idx = 0;
+    let n_workers = workers.len();
568528
569- // keep track of which worker has the root of the plan tree (highest stage
570- // number)
571- let mut max_stage_id = -1 ;
529+ // keep track of where the root of the plan will live (highest stage id)
530+ let mut max_stage_id: i64 = -1 ;
572531 let mut final_addrs = Addrs :: default ( ) ;
573532
574533 for stage in stages {
575- for partition_group in stage. partition_groups . iter ( ) {
534+ for pg in & stage. partition_groups {
576535 let plan_bytes = physical_plan_to_bytes ( stage. plan . clone ( ) , codec) ?;
577536
578- let host = worker_addrs[ worker_idx] . clone ( ) ;
579- worker_idx = ( worker_idx + 1 ) % worker_addrs. len ( ) ;
537+ // pick next worker
538+ let ( host, tx) = workers[ idx] . clone ( ) ;
539+ idx = ( idx + 1 ) % n_workers;
580540
581- if stage . stage_id as isize > max_stage_id {
582- // this wasn't the last stage
583- max_stage_id = stage. stage_id as isize ;
541+ // remember which host serves the final stage
542+ if stage . stage_id as i64 > max_stage_id {
543+ max_stage_id = stage. stage_id as i64 ;
584544 final_addrs. clear ( ) ;
585545 }
586- if stage. stage_id as isize == max_stage_id {
587- for part in partition_group. iter ( ) {
588- // we are the final stage, so we will be the one to serve this partition
546+ if stage. stage_id as i64 == max_stage_id {
547+ for part in pg {
589548 final_addrs
590549 . entry ( stage. stage_id )
591550 . or_default ( )
@@ -595,22 +554,24 @@ fn assign_to_workers(
                 }
             }

-            let task_data = DDTask {
-                query_id: query_id.to_string(),
+            task_datas.push(DDTask {
+                query_id: query_id.to_owned(),
                 stage_id: stage.stage_id,
                 plan_bytes,
-                partition_group: partition_group.to_vec(),
-                child_stage_ids: stage.child_stage_ids().unwrap_or_default().to_vec(),
-                stage_addrs: None, // will be calculated and filled in later
+                partition_group: pg.clone(),
+                child_stage_ids: stage.child_stage_ids().unwrap_or_default(),
+                stage_addrs: None, // filled in later
                 num_output_partitions: stage.plan.output_partitioning().partition_count() as u64,
                 full_partitions: stage.full_partitions,
-                assigned_host: Some(host),
-            };
-            task_datas.push(task_data);
+                assigned_host: Some(host.clone()),
+            });
+
+            // keep the order exactly aligned with task_datas
+            tx_host_pairs.push((tx, host));
         }
     }

-    Ok((task_datas, final_addrs))
+    Ok((task_datas, final_addrs, tx_host_pairs))
 }

 fn get_stage_addrs_from_tasks(target_stage_ids: &[u64], stages: &[DDTask]) -> Result<StageAddrs> {
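For callers, the visible migration is that `distribute_stages` now takes `(Host, Arc<dyn WorkerTransport>)` pairs instead of bare `Host`s. A hedged sketch of what a call site might look like; `transport_for` is a hypothetical registry lookup, not an API shown in this diff:

```rust
// Pair each host with its transport up front, so the retry loop can evict a
// failing worker and its connection together.
let workers: Vec<(Host, Arc<dyn WorkerTransport>)> = worker_addrs
    .into_iter()
    .map(|host| {
        let tx = transport_for(&host); // hypothetical: returns Arc<dyn WorkerTransport>
        (host, tx)
    })
    .collect();

let (addrs, tasks) = distribute_stages(&query_id, stages, workers, codec).await?;
```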