
Commit 8966bc3

replace tp_allgather with tp_cp_allgather where apt
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent e66b3c5 commit 8966bc3

File tree

6 files changed, +24 -11 lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion
@@ -1547,7 +1547,7 @@ def _maybe_sync_cache_data(self, strategy: DistributedTuningStrategy,
     def _merge_cache_data(self, custom_op: str):
         cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
         merged_cache_data = dict()
-        all_cache_data = self._dist.tp_allgather(obj=cache_data)
+        all_cache_data = self._dist.tp_cp_allgather(obj=cache_data)

         for data in all_cache_data:
             for key, value in data.items():

tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 9 additions & 1 deletion
@@ -149,11 +149,19 @@ def tp_cp_allgather(self, obj):
         First gathers within CP group, then across TP groups, returning
         a flattened list with tp_size * cp_size entries.
         """
+        # Gather across CP dimension.
         if self.cp_size > 1:
             obj = self.cp_allgather(obj)
+        else:
+            obj = [obj]  # Wrap to match cp_allgather output format.
+
+        # Gather across TP dimension.
         if self.tp_size > 1:
             obj = self.tp_allgather(obj)
-        # Flatten: [[cp0, cp1], [cp0, cp1], ...] -> [tp0_cp0, tp0_cp1, tp1_cp0, ...].
+        else:
+            obj = [obj]  # Wrap to match tp_allgather output format.
+
+        # Flatten: [[cp0, cp1], [cp0, cp1], ...] -> [tp0_cp0, tp0_cp1, tp1_cp0, ...]
         return [entry for tp_group in obj for entry in tp_group]
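
The two-stage gather above can be exercised without a distributed backend. Below is a minimal sketch (plain Python, not TRT-LLM code) that stubs the per-group gathers with lists for assumed sizes tp_size = 2 and cp_size = 2, just to show the tp-major, cp-minor ordering of the flattened result.

    # Sketch only: simulate the shapes that cp_allgather/tp_allgather would
    # produce, then flatten exactly as tp_cp_allgather does.
    tp_size, cp_size = 2, 2

    # After the CP gather, obj is one list with cp_size entries.
    cp_gathered = [f"cp{c}" for c in range(cp_size)]

    # After the TP gather, obj is one CP-group list per TP rank.
    tp_gathered = [[f"tp{t}_{entry}" for entry in cp_gathered]
                   for t in range(tp_size)]

    # Flatten: [[cp0, cp1], [cp0, cp1]] -> [tp0_cp0, tp0_cp1, tp1_cp0, tp1_cp1].
    flattened = [entry for tp_group in tp_gathered for entry in tp_group]
    print(flattened)  # tp_size * cp_size = 4 entries, TP-major order.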

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 2 additions & 2 deletions
@@ -246,7 +246,7 @@ def maybe_get_cuda_graph(
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = batch.batch_size
         if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
-            all_can_graph_batch = self.config.dist.tp_allgather(
+            all_can_graph_batch = self.config.dist.tp_cp_allgather(
                 [can_run_cuda_graph, batch_size])
             is_all_gen_only = all(all_can_graph[0]
                                   for all_can_graph in all_can_graph_batch)
@@ -408,7 +408,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
         new_batch_size = batch_size

         if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
-            graph_batch_size = self.config.dist.tp_allgather(
+            graph_batch_size = self.config.dist.tp_cp_allgather(
                 [can_run_cuda_graph, batch_size])
             all_can_graph = all(graph_batch[0]
                                 for graph_batch in graph_batch_size)
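
As a hypothetical illustration of how the gathered pairs are consumed here (the data below is made up; only the variable names mirror the diff): with tp_cp_allgather, each entry is one rank's [can_run_cuda_graph, batch_size] pair, so the CUDA-graph decision now spans tp_size * cp_size ranks instead of tp_size alone.

    # Sketch with fabricated values for a tp_size = 2, cp_size = 2 run.
    all_can_graph_batch = [
        [True, 8],   # (tp0, cp0)
        [True, 8],   # (tp0, cp1)
        [True, 4],   # (tp1, cp0)
        [False, 4],  # (tp1, cp1)
    ]

    # Every (tp, cp) rank must be able to run the CUDA graph.
    can_graph_everywhere = all(entry[0] for entry in all_can_graph_batch)
    print(can_graph_everywhere)  # False: one rank cannot run the graph.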

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 2 additions & 0 deletions
@@ -370,6 +370,8 @@ def _fetch_new_requests_attention_dp(
         num_active_tokens = sum(
             [req.py_orig_prompt_len for req in activate_requests])

+        # Note: We use tp_allgather even for CP assuming that all CP ranks in a
+        # DP group have the same num_active_tokens and num_active_requests.
         responses_list = self.dist.tp_allgather(
             [len(activate_requests), num_active_tokens])

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 5 deletions
@@ -1369,7 +1369,7 @@ def get_padded_piecewise_tokens(tokens):
                         max(attn_all_rank_num_tokens)
                         <= max_captured_num_tokens)
                 all_ranks_can_run_piecewise_cuda_graph = list(
-                    self.dist.tp_allgather(can_run_piecewise_cuda_graph))
+                    self.dist.tp_cp_allgather(can_run_piecewise_cuda_graph))
                 if all(all_ranks_can_run_piecewise_cuda_graph):
                     padded_num_tokens = get_padded_piecewise_tokens(
                         max(attn_all_rank_num_tokens))
@@ -1536,7 +1536,7 @@ def _prepare_incremental_update_metadata(
             # Handle distributed spec metadata
             if enable_attention_dp:
                 sequence_lengths = spec_metadata.seq_lens
-                all_rank_num_tokens = self.dist.tp_allgather(
+                all_rank_num_tokens = self.dist.tp_cp_allgather(
                     [spec_metadata.num_tokens,
                      len(sequence_lengths)])
                 spec_metadata.all_rank_num_tokens = [
@@ -2691,7 +2691,7 @@ def previous_seq_slots_device():
             inputs['spec_metadata'] = spec_metadata

             if self.enable_attention_dp:
-                all_rank_num_tokens = self.dist.tp_allgather(
+                all_rank_num_tokens = self.dist.tp_cp_allgather(
                     [spec_metadata.num_tokens,
                      len(sequence_lengths)])

@@ -2856,7 +2856,7 @@ def _prepare_tp_inputs_no_cache(
             # support attention dp
             if self.enable_attention_dp:
                 if spec_metadata is not None:
-                    all_rank_num_tokens = self.dist.tp_allgather([
+                    all_rank_num_tokens = self.dist.tp_cp_allgather([
                         attn_metadata.num_tokens, spec_metadata.num_tokens,
                         len(sequence_lengths)
                     ])
@@ -2871,7 +2871,7 @@ def _prepare_tp_inputs_no_cache(
                     spec_metadata.all_rank_num_tokens = spec_all_rank_num_tokens
                     spec_metadata.all_rank_num_seqs = all_rank_num_seqs
                 else:
-                    all_rank_num_tokens = self.dist.tp_allgather(
+                    all_rank_num_tokens = self.dist.tp_cp_allgather(
                         attn_metadata.num_tokens)
                     attn_metadata.all_rank_num_tokens = all_rank_num_tokens
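
The code that unpacks all_rank_num_tokens is not part of this diff, so the following is only an assumed sketch of the common pattern: the gathered [num_tokens, num_seqs] pairs (one per (tp, cp) rank after this change) are split back into per-rank lists.

    # Assumed sketch with fabricated values; one pair per (tp, cp) rank.
    gathered = [[128, 4], [96, 3], [128, 4], [64, 2]]

    all_rank_num_tokens = [pair[0] for pair in gathered]  # [128, 96, 128, 64]
    all_rank_num_seqs = [pair[1] for pair in gathered]    # [4, 3, 4, 2]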

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 2 deletions
@@ -1207,7 +1207,8 @@ def wait_on_pp_send_handles(self, microbatch_id):
     def _can_queue(self, scheduled_batch):

         if self.enable_attention_dp:
-            tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
+            tp_batch_sizes = self.dist.tp_cp_allgather(
+                scheduled_batch.batch_size)
             can_queue = 0 not in tp_batch_sizes
         else:
             can_queue = scheduled_batch.batch_size > 0
@@ -1552,7 +1553,7 @@ def _executor_loop_overlap(self):
                 if self.enable_attention_dp:
                     local_can_forward = self.executor_request_queue.num_fetch_requests + \
                         len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
-                    all_can_forward = self.dist.tp_allgather(
+                    all_can_forward = self.dist.tp_cp_allgather(
                         local_can_forward)
                     if all(all_can_forward):
                         can_forward = True
@@ -1924,6 +1925,8 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
         num_scheduled_tokens = sum(
             [len(req.get_tokens(0))
              for req in context_requests]) + num_scheduled_generation_requests
+        # Note: We use tp_allgather instead of tp_cp_allgather because we want to
+        # balance the requests across DP ranks, not CP ranks within those DP ranks.
         responses_list = self.dist.tp_allgather([
             num_scheduled_context_requests, num_scheduled_generation_requests,
             num_scheduled_tokens
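
A small illustrative contrast (stubbed numbers, not TRT-LLM code) of the two collectives used above, assuming tp_size = 2 and cp_size = 2 with attention DP: _balance_adp_requests keeps tp_allgather because it balances per DP rank, while readiness checks such as _can_queue use tp_cp_allgather because every (tp, cp) rank must report a non-empty batch.

    # tp_allgather: one entry per DP rank (balancing granularity).
    per_dp_rank_tokens = [12, 7]            # len == tp_size

    # tp_cp_allgather: one entry per (tp, cp) rank (readiness granularity).
    per_rank_batch_sizes = [3, 3, 0, 0]     # len == tp_size * cp_size

    can_queue = 0 not in per_rank_batch_sizes
    print(can_queue)  # False: some ranks have no scheduled requests.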
