Skip to content

Commit d1ab93c

Browse files
committed
revert tp_cp_allgather in multiple places
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent ed0d016 commit d1ab93c

File tree

4 files changed

+7
-19
lines changed

4 files changed

+7
-19
lines changed

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
133133

134134
if (worldConfig.isTensorParallel())
135135
{
136-
if (worldConfig.isContextParallel())
137-
{
138-
// When CP is enabled, group ranks with same (ppRank, cpRank) to exclude both PP and CP.
139-
auto const tpGroupId = worldConfig.getContextParallelRank()
140-
+ worldConfig.getContextParallelism() * worldConfig.getPipelineParallelRank();
141-
mGroupTensorParaComm
142-
= std::make_shared<CacheTransceiverComm>(mGroupComm->split(tpGroupId, worldConfig.getRank()));
143-
}
144-
else
145-
{
146-
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
147-
mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
148-
}
136+
mGroupTensorParaComm = std::make_shared<CacheTransceiverComm>(
137+
mGroupComm->split(worldConfig.getPipelineParallelRank(), worldConfig.getTensorParallelRank()));
149138
}
150139
int kvFactor = 2;
151140
if (cacheManager->getCacheType() == kv_cache_manager::CacheType::kSELFKONLY)

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def maybe_get_cuda_graph(
247247
can_run_cuda_graph = batch.can_run_cuda_graph
248248
batch_size = batch.batch_size
249249
if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
250-
all_can_graph_batch = self.config.dist.tp_cp_allgather(
250+
all_can_graph_batch = self.config.dist.tp_allgather(
251251
[can_run_cuda_graph, batch_size])
252252
is_all_gen_only = all(all_can_graph[0]
253253
for all_can_graph in all_can_graph_batch)
@@ -409,7 +409,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
409409
new_batch_size = batch_size
410410

411411
if self.enabled and self.config.enable_attention_dp and self.config.mapping.tp_size > 1:
412-
graph_batch_size = self.config.dist.tp_cp_allgather(
412+
graph_batch_size = self.config.dist.tp_allgather(
413413
[can_run_cuda_graph, batch_size])
414414
all_can_graph = all(graph_batch[0]
415415
for graph_batch in graph_batch_size)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ def get_padded_piecewise_tokens(tokens):
13691369
max(attn_all_rank_num_tokens)
13701370
<= max_captured_num_tokens)
13711371
all_ranks_can_run_piecewise_cuda_graph = list(
1372-
self.dist.tp_cp_allgather(can_run_piecewise_cuda_graph))
1372+
self.dist.tp_allgather(can_run_piecewise_cuda_graph))
13731373
if all(all_ranks_can_run_piecewise_cuda_graph):
13741374
padded_num_tokens = get_padded_piecewise_tokens(
13751375
max(attn_all_rank_num_tokens))

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,8 +1227,7 @@ def wait_on_pp_send_handles(self, microbatch_id):
12271227
def _can_queue(self, scheduled_batch):
12281228

12291229
if self.enable_attention_dp:
1230-
tp_batch_sizes = self.dist.tp_cp_allgather(
1231-
scheduled_batch.batch_size)
1230+
tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
12321231
can_queue = 0 not in tp_batch_sizes
12331232
else:
12341233
can_queue = scheduled_batch.batch_size > 0
@@ -1573,7 +1572,7 @@ def _executor_loop_overlap(self):
15731572
if self.enable_attention_dp:
15741573
local_can_forward = self.executor_request_queue.num_fetch_requests + \
15751574
len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
1576-
all_can_forward = self.dist.tp_cp_allgather(
1575+
all_can_forward = self.dist.tp_allgather(
15771576
local_can_forward)
15781577
if all(all_can_forward):
15791578
can_forward = True

0 commit comments

Comments (0)