@@ -351,13 +351,14 @@ def wrapper(
         def _call_func():
             return func(self, scheduled_requests, resource_manager, *args, **kwargs)

-        # check if we use cuda graph and we can run it
-        if not (self.cuda_graph_used and scheduled_requests.can_run_cuda_graph):
-            return _call_func()
+        # check conditions for current rank
+        can_run_cuda_graph = self.cuda_graph_used and scheduled_requests.can_run_cuda_graph
+        batch_size = scheduled_requests.batch_size

         # generate a persistent dummy request right away to ensure we can reserve the necessary
-        # resources (kv page and slot)
-        if self.padding_dummy_request is None:
+        # resources (kv page and slot) the first time we can actually run cuda graph according to
+        # this rank
+        if can_run_cuda_graph and self.padding_dummy_request is None:
             self.padding_dummy_request = _generate_dummy_request(
                 resource_manager,
                 request_id=CUDA_GRAPH_DUMMY_REQUEST_ID,
@@ -367,20 +368,45 @@ def _call_func():
                 max_beam_width=self.max_beam_width,
             )

-        # check closest cuda graph batch size
-        closest_cg_bs = _round_up_to_closest(
-            self.cuda_graph_batch_sizes, scheduled_requests.batch_size
-        )
+        # check if we can pad the batch based on the availability of the dummy request
+        can_pad = self.padding_dummy_request is not None
+
+        # in attention DP mode, we check all ranks
+        if self.enable_attention_dp and self.mapping.tp_size > 1:
+            assert self.dist is not None, "Distributed object is required for attention DP mode"
+            all_rank_info = self.dist.tp_allgather([can_run_cuda_graph, can_pad, batch_size])
+        else:
+            all_rank_info = [[can_run_cuda_graph, can_pad, batch_size]]
+
+        # now check whether we can run cuda graph and pad the batch on all ranks
+        can_run_cuda_graph_all = all(r_info[0] for r_info in all_rank_info)
+        max_batch_size = max(r_info[2] for r_info in all_rank_info)
+
+        # check that all ranks can pad the batch if they need to
+        can_pad_all = all(r_info[1] or (r_info[2] == max_batch_size) for r_info in all_rank_info)
+
+        # fall back if we cannot run cuda graph on all ranks
+        if not (can_run_cuda_graph_all and can_pad_all):
+            return _call_func()

-        # check if we need to pad
-        num_padding = closest_cg_bs - scheduled_requests.batch_size
+        # check if a cudagraph batch size is available
+        # NOTE: we assume uniform cudagraph batch sizes across all ranks, so every rank resolves
+        # the same closest cudagraph batch size from the max batch size across all ranks
+        closest_cg_bs = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)

-        if num_padding <= 0:
+        if closest_cg_bs is None:
             return _call_func()

-        # check if we have a dummy request to use
-        if self.padding_dummy_request is None:
-            ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
+        # check actual amount of padding needed
+        num_padding = closest_cg_bs - batch_size
+
+        # at this point either no padding is needed or a dummy request must be available
+        assert num_padding == 0 or (num_padding > 0 and self.padding_dummy_request is not None), (
+            "Either no padding should be needed or a dummy request should be available here"
+        )
+
+        # no padding needed on current rank
+        if num_padding == 0:
             return _call_func()

         # pad the scheduled requests with the dummy request
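For reference, the padding decision introduced in this hunk reduces to a small cross-rank consensus: CUDA graph execution is only attempted if every rank can run it and every rank that would need padding has a dummy request. Below is a minimal standalone sketch of that check (a simplified restatement, not the PR's code verbatim); all_rank_info stands in for the per-rank [can_run_cuda_graph, can_pad, batch_size] triples gathered via tp_allgather above.

from typing import List, Tuple

def cuda_graph_consensus(all_rank_info: List[list]) -> Tuple[bool, int]:
    # True only if every rank can capture/replay the CUDA graph
    can_run_all = all(info[0] for info in all_rank_info)
    # every rank pads up to the largest batch size across ranks
    max_bs = max(info[2] for info in all_rank_info)
    # a rank only needs a dummy request if it actually has to pad up to max_bs
    can_pad_all = all(info[1] or info[2] == max_bs for info in all_rank_info)
    return can_run_all and can_pad_all, max_bs

# Example: rank 1 cannot pad but already holds the max batch size, so consensus holds.
ok, max_bs = cuda_graph_consensus([[True, True, 3], [True, False, 5]])
assert ok and max_bs == 5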
@@ -411,7 +437,12 @@ def _device(self) -> DeviceLikeType:
         return self.cache_seq_interface.device

     @classmethod
-    def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
+    def build_from_config(
+        cls,
+        ad_config: LlmArgs,
+        mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
+    ):
         """Build the ADEngine using the LlmArgs that gets passed through from the LLM."""

         max_batch_size = ad_config.max_batch_size
@@ -453,6 +484,7 @@ def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None
             device,
             ad_config=ad_config,
             mapping=mapping,
+            dist=dist,
             reporting_info=reporting_info,
         )

@@ -464,6 +496,7 @@ def __init__(
         device: DeviceLikeType,
         ad_config: Optional[LlmArgs] = None,
         mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
         reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
@@ -484,7 +517,7 @@ def __init__(
         self.iter_states = {}

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
-        self.enable_attention_dp = False
+        self.enable_attention_dp = mapping.enable_attention_dp if mapping else False

         if ad_config is not None:
             self.max_beam_width = ad_config.max_beam_width
@@ -537,6 +570,7 @@ def __init__(

         # Reuse _execute_logit_post_processors from PyTorchModelEngine
         self.mapping = mapping
+        self.dist = dist
         self._execute_logit_post_processors = types.MethodType(
             PyTorchModelEngine._execute_logit_post_processors, self
         )
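The types.MethodType(...) call above binds a function defined on PyTorchModelEngine to this engine instance so it can be called like a regular method without inheriting from that class. A minimal, self-contained illustration of the binding pattern (the Donor/Engine classes here are hypothetical, not part of the PR):

import types

class Donor:
    def greet(self):
        # plain function defined on another class
        return f"hello from {type(self).__name__}"

class Engine:
    def __init__(self):
        # bind Donor.greet to *this* instance; self inside greet is the Engine object
        self.greet = types.MethodType(Donor.greet, self)

assert Engine().greet() == "hello from Engine"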
@@ -1005,13 +1039,23 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     # initialize process groups
     world_size = mpi_world_size()
     rank = mpi_rank()
-    dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
+    enable_attention_dp = ad_config.transforms.get("detect_sharding", {}).get(
+        "enable_attention_dp", False
+    )
+    dist_mapping = Mapping(
+        rank=rank,
+        world_size=world_size,
+        tp_size=world_size,
+        enable_attention_dp=enable_attention_dp,
+    )
     dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
     port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
     initialize_or_skip(rank, world_size, port)

+    ad_logger.info(f"{dist_mapping=}, {dist=}, {port=}")
+
     # Setup AutoTuner with distributed state for allreduce autotuning
     AutoTuner.get().setup_distributed_state(dist_mapping)

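The chained .get("detect_sharding", {}).get("enable_attention_dp", False) above reads an optional flag from a possibly absent transform entry without raising. A tiny sketch of that defaulting pattern using plain dicts (the PR reads the real value from ad_config.transforms):

transforms = {"detect_sharding": {"enable_attention_dp": True}}
assert transforms.get("detect_sharding", {}).get("enable_attention_dp", False) is True

# a missing section or a missing key both fall back to False instead of raising KeyError
assert {}.get("detect_sharding", {}).get("enable_attention_dp", False) is False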
@@ -1030,7 +1074,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )

     # initialize model engine
-    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
+    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping, dist=dist)

     spec_config = ad_config.speculative_config
     if spec_config is not None and not (