Commit 97b4e05

fix cudagraph padding logic
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent abe6030 commit 97b4e05

File tree

1 file changed (+14 −11)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 14 additions & 11 deletions

@@ -378,27 +378,30 @@ def _call_func():
         else:
             all_rank_info = [[can_run_cuda_graph, can_pad, batch_size]]
 
-        # now let's check if we can run cuda graph and pad the batch for all ranks
+        # now let's check if we can in principle run cuda graph across all ranks
         can_run_cuda_graph_all = all(r_info[0] for r_info in all_rank_info)
-        max_batch_size = max(r_info[2] for r_info in all_rank_info)
-
-        # let's check if all ranks can pad the batch if they need to
-        can_pad_all = all(r_info[1] or (r_info[2] == max_batch_size) for r_info in all_rank_info)
 
-        # fall back if we cannot run cudagraph
-        if not (can_run_cuda_graph_all and can_pad_all):
+        if not can_run_cuda_graph_all:
             return _call_func()
 
-        # check if cudagraph batch size is available
+        # get closest cudagraph batch size based on max_batch_size across ALL ranks
         # NOTE: we assume uniform cudagraph batch sizes across all ranks ensuring all ranks get the
         # same closest cudagraph batch size here based on the max batch size across all ranks
-        closest_cg_bs = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)
+        max_batch_size = max(r_info[2] for r_info in all_rank_info)
+        cg_batch_size = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)
+
+        if cg_batch_size is None:
+            return _call_func()
+
+        # let's check if all ranks can pad the batch if they need to
+        can_pad_all = all(r_info[1] or (r_info[2] == cg_batch_size) for r_info in all_rank_info)
 
-        if closest_cg_bs is None:
+        # fall back if we cannot run cudagraph due to padding issues
+        if not can_pad_all:
             return _call_func()
 
         # check actual amount of padding needed
-        num_padding = closest_cg_bs - batch_size
+        num_padding = cg_batch_size - batch_size
 
         # we should only hit this point for either of these conditions
         assert num_padding == 0 or (num_padding > 0 and self.padding_dummy_request is not None), (
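For reference, below is a minimal standalone sketch of the fixed decision flow. The ordering mirrors the patched code: check cuda graph eligibility first, then resolve the target cudagraph batch size from the max batch size across all ranks, then check padding against that target. The `_round_up_to_closest` body and the `pick_cudagraph_batch_size` wrapper are hypothetical stand-ins for illustration only; the real helper lives in ad_executor.py and is not shown in this diff.

# Hypothetical, self-contained sketch of the fixed cudagraph padding decision.
# Names mirror the diff; the harness itself is an assumption for illustration.
from typing import List, Optional, Sequence


def _round_up_to_closest(batch_sizes: Sequence[int], batch_size: int) -> Optional[int]:
    # Assumed behavior of the helper referenced in the diff: return the smallest
    # configured cudagraph batch size >= batch_size, or None if none is large enough.
    eligible = [bs for bs in sorted(batch_sizes) if bs >= batch_size]
    return eligible[0] if eligible else None


def pick_cudagraph_batch_size(
    all_rank_info: List[List],  # per rank: [can_run_cuda_graph, can_pad, batch_size]
    cuda_graph_batch_sizes: Sequence[int],
) -> Optional[int]:
    # 1) all ranks must be able to run a cuda graph at all
    if not all(r_info[0] for r_info in all_rank_info):
        return None

    # 2) the target size comes from the max batch size across ALL ranks
    max_batch_size = max(r_info[2] for r_info in all_rank_info)
    cg_batch_size = _round_up_to_closest(cuda_graph_batch_sizes, max_batch_size)
    if cg_batch_size is None:
        return None

    # 3) every rank must either be able to pad or already sit at the target size;
    #    the pre-fix code compared against max_batch_size instead of cg_batch_size
    if not all(r_info[1] or (r_info[2] == cg_batch_size) for r_info in all_rank_info):
        return None

    return cg_batch_size


if __name__ == "__main__":
    # Rank 0 has batch size 3 and cannot pad; the target cudagraph size is 4,
    # so the fixed ordering correctly falls back (prints None), whereas the old
    # padding check would have passed because 3 == max_batch_size.
    ranks = [[True, False, 3], [True, True, 2]]
    print(pick_cudagraph_batch_size(ranks, [1, 2, 4, 8]))

Returning None in this sketch corresponds to the patched code falling back to _call_func() (eager execution) instead of replaying a cuda graph.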
