@@ -378,27 +378,30 @@ def _call_func():
378378 else :
379379 all_rank_info = [[can_run_cuda_graph , can_pad , batch_size ]]
380380
381- # now let's check if we can run cuda graph and pad the batch for all ranks
381+ # now let's check if we can in principle run cuda graph across all ranks
382382 can_run_cuda_graph_all = all (r_info [0 ] for r_info in all_rank_info )
383- max_batch_size = max (r_info [2 ] for r_info in all_rank_info )
384-
385- # let's check if all ranks can pad the batch if they need to
386- can_pad_all = all (r_info [1 ] or (r_info [2 ] == max_batch_size ) for r_info in all_rank_info )
387383
388- # fall back if we cannot run cudagraph
389- if not (can_run_cuda_graph_all and can_pad_all ):
384+ if not can_run_cuda_graph_all :
390385 return _call_func ()
391386
392- # check if cudagraph batch size is available
387+ # get closest cudagraph batch size based on max_batch_size across ALL ranks
393388 # NOTE: we assume uniform cudagraph batch sizes across all ranks ensuring all ranks get the
394389 # same closest cudagraph batch size here based on the max batch size across all ranks
395- closest_cg_bs = _round_up_to_closest (self .cuda_graph_batch_sizes , max_batch_size )
390+ max_batch_size = max (r_info [2 ] for r_info in all_rank_info )
391+ cg_batch_size = _round_up_to_closest (self .cuda_graph_batch_sizes , max_batch_size )
392+
393+ if cg_batch_size is None :
394+ return _call_func ()
395+
396+ # let's check if all ranks can pad the batch if they need to
397+ can_pad_all = all (r_info [1 ] or (r_info [2 ] == cg_batch_size ) for r_info in all_rank_info )
396398
397- if closest_cg_bs is None :
399+ # fall back if we cannot run cudagraph due to padding issues
400+ if not can_pad_all :
398401 return _call_func ()
399402
400403 # check actual amount of padding needed
401- num_padding = closest_cg_bs - batch_size
404+ num_padding = cg_batch_size - batch_size
402405
403406 # we should only hit this point for either of these conditions
404407 assert num_padding == 0 or (num_padding > 0 and self .padding_dummy_request is not None ), (
0 commit comments