
Commit e224184

Revert "[TRTLLM-5972][chore] Load balance decode token KV cache with helix parallelism"
This reverts commit 6b60df1.
1 parent 28355cc commit e224184

File tree: 3 files changed (+10, -15 lines changed)

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 0 additions & 1 deletion
@@ -694,7 +694,6 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
                 position_ids=position_ids_this_rank,
             )
             req.total_input_len_cp = input_len
-            req.seqlen_this_rank_cp = len(input_ids_this_rank)
             req_with_children.append(req)
             if req.child_requests:
                 req_with_children.extend(req.child_requests)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 5 deletions
@@ -1671,12 +1671,12 @@ def _prepare_tp_inputs(
                 # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
                 if not self.is_warmup and not request.is_cuda_graph_dummy:
                     position_id = request.total_input_len_cp + request.py_decoding_iter - 1
-                    if request.py_helix_is_inactive_rank:
-                        past_seen_token_num = request.seqlen_this_rank_cp
+                    # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
+                    if self.mapping.cp_rank == self.mapping.cp_size - 1:
+                        past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
                     else:
-                        # Discount the token added to active rank in resource manager as it hasn't
-                        # been previously seen.
-                        past_seen_token_num = request.seqlen_this_rank_cp - 1
+                        # past_seen_token_num doesn't grow on inactive ranks.
+                        past_seen_token_num = request.orig_prompt_len

                     position_ids.append(position_id)
                     num_cached_tokens_per_seq.append(past_seen_token_num)
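
Note: this hunk swaps the cached-token count back from per-rank sequence-length tracking to the older last-rank-is-active rule. A minimal, self-contained sketch of the two schemes follows; the function names and arguments are stand-ins for the request and mapping fields used above, not TensorRT-LLM APIs.

# Illustrative sketch of how past_seen_token_num differs between the reverted
# round-robin scheme and the restored last-rank-active scheme.
# All names are stand-ins; this is not repository code.

def past_seen_last_rank_active(orig_prompt_len: int, decoding_iter: int,
                               cp_rank: int, cp_size: int) -> int:
    """Restored behavior: only the last CP rank accumulates decode tokens."""
    if cp_rank == cp_size - 1:
        return orig_prompt_len + decoding_iter - 1
    # Inactive ranks never grow their cached-token count.
    return orig_prompt_len


def past_seen_round_robin(seqlen_this_rank_cp: int, is_inactive_rank: bool) -> int:
    """Reverted behavior: the active rank just had one token added in the
    resource manager, so discount it because it hasn't been seen yet."""
    if is_inactive_rank:
        return seqlen_this_rank_cp
    return seqlen_this_rank_cp - 1


if __name__ == "__main__":
    # With a 10-token prompt at decode iteration 3 on 2 CP ranks:
    print(past_seen_last_rank_active(10, 3, cp_rank=1, cp_size=2))  # 12 (active last rank)
    print(past_seen_last_rank_active(10, 3, cp_rank=0, cp_size=2))  # 10 (inactive rank)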

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 5 additions & 9 deletions
@@ -468,17 +468,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                     req, block_ids)

         for req in generation_batch:
+            # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
             if self.mapping.has_cp_helix():
-                # Distribute the decode blocks across CP ranks in a round-robin manner.
-                decode_block_id = (req.py_decoding_iter -
-                                   1) // self.tokens_per_block
-                if decode_block_id % self.mapping.cp_size == self.mapping.cp_rank:
-                    req.py_helix_is_inactive_rank = False
-                    req.seqlen_this_rank_cp += 1
-                else:
+                if self.mapping.cp_rank != self.mapping.cp_size - 1:
                     req.py_helix_is_inactive_rank = True
-                    # Skip allocating KV cache at decode for inactive helix ranks.
-                    continue
+            # Skip allocating KV cache at decode for inactive helix ranks.
+            if req.py_helix_is_inactive_rank:
+                continue
             self.impl.add_token(req.py_request_id)
             for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)
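
Note: the deleted branch balanced decode KV cache by assigning each decode block to a CP rank in round-robin order, while the restored code allocates decode blocks only on the last rank. A small sketch of the two ownership rules, using purely hypothetical helper names rather than TensorRT-LLM APIs:

# Illustrative sketch of the round-robin block ownership that the revert removes,
# versus always treating the last CP rank as the active one.

def active_rank_round_robin(decoding_iter: int, tokens_per_block: int,
                            cp_size: int) -> int:
    """Reverted scheme: the decode block index picks the owning CP rank."""
    decode_block_id = (decoding_iter - 1) // tokens_per_block
    return decode_block_id % cp_size


def active_rank_last_only(cp_size: int) -> int:
    """Restored scheme: the last CP rank always owns decode KV cache."""
    return cp_size - 1


if __name__ == "__main__":
    cp_size, tokens_per_block = 4, 32
    # Decode iterations 1-32 map to block 0 (rank 0), 33-64 to block 1 (rank 1), etc.
    for it in (1, 32, 33, 65, 129):
        print(it, active_rank_round_robin(it, tokens_per_block, cp_size),
              active_rank_last_only(cp_size))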
