Skip to content

Commit d1ad962

Browse files
committed
simplify further
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 799a339 commit d1ad962

File tree

1 file changed

+7
-24
lines changed

1 file changed

+7
-24
lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 7 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -362,33 +362,16 @@ def _fetch_new_requests_attention_dp(
362362
# Get active request counts across all ranks.
363363
all_ranks_num_active_requests = []
364364
all_ranks_num_active_tokens = []
365-
num_active_tokens = sum(
366-
[req.py_orig_prompt_len for req in activate_requests])
367365

368366
if self.dist.has_cp_helix:
369-
# When CP is enabled with Helix parallelism, we need to gather from all ranks
370-
# in the TP x CP space. CP ranks within the same DP group (same tp_rank) handle
371-
# the same requests with different token portions (sequence is split across CP ranks).
372-
responses_list = self.dist.tp_cp_allgather(
373-
[len(activate_requests), num_active_tokens])
374-
375-
aggregated_responses = []
376-
for dp_group_idx in range(self.dist.tp_size):
377-
# Get all entries for this DP group (cp_size entries per group).
378-
group_start = dp_group_idx * self.dist.cp_size
379-
group_end = (dp_group_idx + 1) * self.dist.cp_size
380-
group_entries = responses_list[group_start:group_end]
381-
382-
# All CP ranks within a DP group should have the same number of requests.
383-
assert all(entry[0] == group_entries[0][0] for entry in group_entries), \
384-
f"CP ranks within DP group {dp_group_idx} have mismatched request counts: " \
385-
f"{[entry[0] for entry in group_entries]}"
386-
# Use token count from cp_rank0.
387-
aggregated_responses.append(group_entries[0])
388-
responses_list = aggregated_responses
367+
num_active_tokens = sum(
368+
[req.total_input_len_cp for req in activate_requests])
389369
else:
390-
responses_list = self.dist.tp_allgather(
391-
[len(activate_requests), num_active_tokens])
370+
num_active_tokens = sum(
371+
[req.py_orig_prompt_len for req in activate_requests])
372+
373+
responses_list = self.dist.tp_allgather(
374+
[len(activate_requests), num_active_tokens])
392375

393376
for num_active_requests, num_active_tokens in responses_list:
394377
all_ranks_num_active_requests.append(num_active_requests)

0 commit comments

Comments
 (0)