Commit e2c9817

[#10063][feat] AutoDeploy attention dp support

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Parent: 069ad68

4 files changed: +90, -21 lines

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (63 additions, 19 deletions)
```diff
@@ -351,13 +351,14 @@ def wrapper(
         def _call_func():
             return func(self, scheduled_requests, resource_manager, *args, **kwargs)
 
-        # check if we use cuda graph and we can run it
-        if not (self.cuda_graph_used and scheduled_requests.can_run_cuda_graph):
-            return _call_func()
+        # check conditions for current rank
+        can_run_cuda_graph = self.cuda_graph_used and scheduled_requests.can_run_cuda_graph
+        batch_size = scheduled_requests.batch_size
 
         # generate a persistent dummy request right away to ensure we can reserve the necessary
-        # resources (kv page and slot)
-        if self.padding_dummy_request is None:
+        # resources (kv page and slot) the first time we can actually run cuda graph according to
+        # this rank
+        if can_run_cuda_graph and self.padding_dummy_request is None:
             self.padding_dummy_request = _generate_dummy_request(
                 resource_manager,
                 request_id=CUDA_GRAPH_DUMMY_REQUEST_ID,
@@ -367,20 +368,45 @@ def _call_func():
                 max_beam_width=self.max_beam_width,
             )
 
-        # check closest cuda graph batch size
-        closest_cg_bs = _round_up_to_closest(
-            self.cuda_graph_batch_sizes, scheduled_requests.batch_size
-        )
+        # check if we can pad the batch based on the availability of the dummy request
+        can_pad = self.padding_dummy_request is not None
+
+        # in attention DP mode, we check all ranks
+        if self.enable_attention_dp and self.mapping.tp_size > 1:
+            assert self.dist is not None, "Distributed object is required for attention DP mode"
+            all_rank_info = self.dist.tp_allgather([can_run_cuda_graph, can_pad, batch_size])
+        else:
+            all_rank_info = [[can_run_cuda_graph, can_pad, batch_size]]
+
+        # now let's check if we can run cuda graph and pad the batch for all ranks
+        can_run_cuda_graph_all = all(r_info[0] for r_info in all_rank_info)
+        max_batch_size = max(r_info[2] for r_info in all_rank_info)
+
+        # let's check if all ranks can pad the batch if they need to
+        can_pad_all = all(r_info[1] or (r_info[2] == max_batch_size) for r_info in all_rank_info)
+
+        # fall back if we cannot run cudagraph
+        if not (can_run_cuda_graph_all and can_pad_all):
+            return _call_func()
 
-        # check if we need to pad
-        num_padding = closest_cg_bs - scheduled_requests.batch_size
+        # check if cudagraph batch size is available
+        # NOTE: we assume uniform cudagraph batch sizes across all ranks ensuring all ranks get the
+        # same closest cudagraph batch size here based on the max batch size across all ranks
+        closest_cg_bs = _round_up_to_closest(self.cuda_graph_batch_sizes, max_batch_size)
 
-        if num_padding <= 0:
+        if closest_cg_bs is None:
             return _call_func()
 
-        # check if we have a dummy request to use
-        if self.padding_dummy_request is None:
-            ad_logger.info("No CUDA graph padding possible due to missing dummy request.")
+        # check actual amount of padding needed
+        num_padding = closest_cg_bs - batch_size
+
+        # we should only hit this point for either of these conditions
+        assert num_padding == 0 or (num_padding > 0 and self.padding_dummy_request is not None), (
+            "Padding should not be needed or available at this point"
+        )
+
+        # no padding needed on current rank
+        if num_padding == 0:
             return _call_func()
 
         # pad the scheduled requests with the dummy request
```
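In short, each rank now contributes a triple `[can_run_cuda_graph, can_pad, batch_size]` via `dist.tp_allgather`, and the CUDA graph path plus padding is only taken when every rank agrees. The decision is a pure function of the gathered info; below is a minimal, self-contained sketch of it (the function names and the `round_up_to_closest` re-implementation are illustrative stand-ins, not the code from this file):

```python
from typing import List, Optional, Tuple


def round_up_to_closest(batch_sizes: List[int], batch_size: int) -> Optional[int]:
    """Illustrative stand-in: smallest configured cudagraph batch size >= batch_size, else None."""
    candidates = [bs for bs in sorted(batch_sizes) if bs >= batch_size]
    return candidates[0] if candidates else None


def decide_cuda_graph_padding(
    all_rank_info: List[List],  # one entry per rank: [can_run_cuda_graph, can_pad, batch_size]
    cuda_graph_batch_sizes: List[int],
    local_batch_size: int,
) -> Tuple[bool, int]:
    """Return (use_cuda_graph, num_padding) mirroring the checks in the hunk above."""
    # all ranks must be able to run the cuda graph
    can_run_all = all(info[0] for info in all_rank_info)
    # every rank below the global max batch size must be able to pad (i.e. has a dummy request)
    max_bs = max(info[2] for info in all_rank_info)
    can_pad_all = all(info[1] or info[2] == max_bs for info in all_rank_info)
    if not (can_run_all and can_pad_all):
        return False, 0  # fall back to eager execution

    # uniform cudagraph batch sizes across ranks -> every rank resolves the same target size
    closest_cg_bs = round_up_to_closest(cuda_graph_batch_sizes, max_bs)
    if closest_cg_bs is None:
        return False, 0

    return True, closest_cg_bs - local_batch_size


# rank 0 has 3 requests, rank 1 has 5; both can pad, so both pad up to the cudagraph size 8
print(decide_cuda_graph_padding([[True, True, 3], [True, True, 5]], [1, 2, 4, 8], 3))  # (True, 5)
```
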
```diff
@@ -411,7 +437,12 @@ def _device(self) -> DeviceLikeType:
         return self.cache_seq_interface.device
 
     @classmethod
-    def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None):
+    def build_from_config(
+        cls,
+        ad_config: LlmArgs,
+        mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
+    ):
         """Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
 
         max_batch_size = ad_config.max_batch_size
@@ -453,6 +484,7 @@ def build_from_config(cls, ad_config: LlmArgs, mapping: Optional[Mapping] = None
             device,
             ad_config=ad_config,
             mapping=mapping,
+            dist=dist,
             reporting_info=reporting_info,
         )
 
@@ -464,6 +496,7 @@ def __init__(
         device: DeviceLikeType,
         ad_config: Optional[LlmArgs] = None,
         mapping: Optional[Mapping] = None,
+        dist: Optional[Distributed] = None,
         reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
@@ -484,7 +517,7 @@ def __init__(
         self.iter_states = {}
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
-        self.enable_attention_dp = False
+        self.enable_attention_dp = mapping.enable_attention_dp if mapping else False
 
         if ad_config is not None:
             self.max_beam_width = ad_config.max_beam_width
@@ -537,6 +570,7 @@ def __init__(
 
         # Reuse _execute_logit_post_processors from PyTorchModelEngine
         self.mapping = mapping
+        self.dist = dist
         self._execute_logit_post_processors = types.MethodType(
             PyTorchModelEngine._execute_logit_post_processors, self
         )
@@ -1005,13 +1039,23 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     # initialize process groups
     world_size = mpi_world_size()
     rank = mpi_rank()
-    dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
+    enable_attention_dp = ad_config.transforms.get("detect_sharding", {}).get(
+        "enable_attention_dp", False
+    )
+    dist_mapping = Mapping(
+        rank=rank,
+        world_size=world_size,
+        tp_size=world_size,
+        enable_attention_dp=enable_attention_dp,
+    )
     dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
     port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
     initialize_or_skip(rank, world_size, port)
 
+    ad_logger.info(f"{dist_mapping=}, {dist=}, {port=}")
+
     # Setup AutoTuner with distributed state for allreduce autotuning
     AutoTuner.get().setup_distributed_state(dist_mapping)
 
@@ -1030,7 +1074,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
 
     # initialize model engine
-    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping)
+    engine = ADEngine.build_from_config(ad_config=ad_config, mapping=dist_mapping, dist=dist)
 
     spec_config = ad_config.speculative_config
     if spec_config is not None and not (
```
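The last two hunks show where the flag enters the runtime: `create_autodeploy_executor` reads `enable_attention_dp` from the `detect_sharding` transform options, forwards it into the `Mapping`, and `ADEngine.__init__` (earlier hunk) derives `self.enable_attention_dp` from that mapping. A minimal sketch of this plumbing, using a dataclass stand-in for `Mapping` so it runs standalone (only the fields touched by this diff are modeled):

```python
from dataclasses import dataclass


@dataclass
class MappingStandIn:
    """Illustrative stand-in for tensorrt_llm's Mapping; only the fields used in this diff."""
    rank: int
    world_size: int
    tp_size: int
    enable_attention_dp: bool = False


# transforms config shaped like the one used by the new accuracy test below
transforms = {"detect_sharding": {"enable_attention_dp": True}}

# mirrors create_autodeploy_executor: read the flag with a default of False ...
enable_attention_dp = transforms.get("detect_sharding", {}).get("enable_attention_dp", False)

# ... and forward it into the mapping shared by the distributed setup and the engine
dist_mapping = MappingStandIn(rank=0, world_size=4, tp_size=4, enable_attention_dp=enable_attention_dp)

# mirrors ADEngine.__init__: the engine flag is derived from the mapping (False without one)
engine_enable_attention_dp = dist_mapping.enable_attention_dp if dist_mapping else False
print(engine_enable_attention_dp)  # True
```
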

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (8 additions, 2 deletions)
```diff
@@ -150,6 +150,11 @@ class ShardingTransformConfig(TransformConfig):
 
     process_grid: Dict[ShardingDim, int] = Field(default_factory=dict)
 
+    enable_attention_dp: bool = Field(
+        default=False,
+        description="When True, skip TP sharding as attention data parallelism is enabled.",
+    )
+
     def validate_config(self, sources: Union[ShardingSource, List[ShardingSource]] = None) -> bool:
         init_process_grid_from_config(self)
         if sources is None:
@@ -737,8 +742,9 @@ def _apply(
             f"Using allreduce strategy: {config.allreduce_strategy.name}, dist backend: {config.dist_backend}"
         )
 
-        if world_size < 2:
-            ad_logger.info("Skipping sharding for single device")
+        if world_size < 2 or config.enable_attention_dp:
+            reason = "single device" if world_size < 2 else "attention DP enabled"
+            ad_logger.info(f"Skipping sharding: {reason}")
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
```
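With the flag set, the sharding transform now returns early with `skipped=True`, so the graph is left unsharded on every rank and the skip is logged with its reason. A small runnable sketch of just that early-return condition, with a trimmed pydantic stand-in for `ShardingTransformConfig` (only the new field is modeled):

```python
from typing import Optional

from pydantic import BaseModel, Field


class ShardingConfigSketch(BaseModel):
    """Trimmed stand-in for ShardingTransformConfig carrying only the new field."""
    enable_attention_dp: bool = Field(
        default=False,
        description="When True, skip TP sharding as attention data parallelism is enabled.",
    )


def skip_reason(world_size: int, config: ShardingConfigSketch) -> Optional[str]:
    """Mirror of the early-return check in _apply: return the skip reason, or None to shard."""
    if world_size < 2 or config.enable_attention_dp:
        return "single device" if world_size < 2 else "attention DP enabled"
    return None


print(skip_reason(4, ShardingConfigSketch(enable_attention_dp=True)))  # attention DP enabled
print(skip_reason(1, ShardingConfigSketch()))                          # single device
print(skip_reason(4, ShardingConfigSketch()))                          # None
```
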

tests/integration/defs/accuracy/test_llm_api_autodeploy.py (18 additions, 0 deletions)
```diff
@@ -79,6 +79,24 @@ def test_auto_dtype(self, world_size, enable_chunked_prefill):
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)
 
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.skip_less_device(2)
+    @pytest.mark.parametrize("world_size", [2, 4])
+    def test_attention_dp(self, world_size):
+        """Test attention data parallelism mode where TP sharding is disabled."""
+        kwargs = self.get_default_kwargs(enable_chunked_prefill=True)
+        # Enable attention DP - this disables TP sharding
+        kwargs["transforms"]["detect_sharding"] = {"enable_attention_dp": True}
+        sampling_params = self.get_default_sampling_params()
+        with AutoDeployLLM(model=self.MODEL_PATH,
+                           tokenizer=self.MODEL_PATH,
+                           world_size=world_size,
+                           **kwargs) as llm:
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+
 
 class TestNemotronH(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-H-8B-Base-8K"
```

tests/integration/test_lists/test-db/l0_dgx_h100.yml (1 addition, 0 deletions)
```diff
@@ -321,4 +321,5 @@ l0_dgx_h100:
   tests:
   - unittest/_torch/auto_deploy/unit/multigpu
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
+  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_attention_dp[4]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
```
