Commit 54b5acb (parent 536573f)

[TRTLLM-9687][perf] Improve performance of _are_end_id

- Introduced `end_ids` tensor in the `Store` class to store end IDs for each request.
- Updated `setup_sampler_step` to fill `end_ids` based on request parameters.
- Refactored the `_are_end_id` method to utilize the new `end_ids` tensor for better performance.

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
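For context (a reading of the diff below, not part of the commit message itself): `setup_sampler_step` now writes each request's end ID into a persistent, device-resident buffer indexed by sequence slot, using `-1` as a sentinel when a request has no end ID, so later steps never rebuild this information. A minimal sketch of that setup side, with hypothetical slot/end-ID values and assuming a CUDA device:

```python
import torch

max_num_sequences = 8
# Persistent buffer analogous to Store.end_ids, allocated once on the device.
end_ids = torch.empty(max_num_sequences, dtype=torch.int32, device="cuda")

# Hypothetical (seq_slot, py_end_id) pairs for the currently scheduled requests.
scheduled = [(0, 2), (3, None), (5, 7)]
for seq_slot, py_end_id in scheduled:
    # Mirrors the diff: write the request's end id, or -1 if it has none.
    end_ids[seq_slot].fill_(py_end_id if py_end_id is not None else -1)
```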

1 file changed (+12, -17 lines): tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -932,6 +932,9 @@ class Store:
     max_lengths_tensor: torch.Tensor
     """Shape: batch_size
     Usage: Stores the maximum lengths for each request"""
+    end_ids: torch.Tensor
+    """Shape: batch_size
+    Usage: Stores the end ids for each request"""
     finish_reasons: torch.Tensor
     """Shape: max_tokens, batch_size, beam_width
     Usage: Stores the currently estimated finish_reasons for each request"""
@@ -977,7 +980,8 @@ def _create_store(self) -> Store:
         # Tensors necessary for all sampling methods
         new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
         finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
-        max_lengths_tensor=int_tensor(self.max_num_sequences),
+        max_lengths_tensor=int_tensor(self.max_num_sequences)
+        end_ids=int_tensor(self.max_num_sequences)
 
         # Only used for logprobs processing or beam search
         sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@@ -1012,6 +1016,7 @@ def _create_store(self) -> Store:
             new_tokens=new_tokens,
             finish_reasons=finish_reasons,
             max_lengths_tensor=max_lengths_tensor,
+            end_ids=end_ids,
             cache_indirection=cache_indirection,
             cache_indirection_buffer=cache_indirection_buffer,
             cum_log_probs=cum_log_probs,
@@ -1549,6 +1554,9 @@ def setup_sampler_step(self, requests: ScheduledRequests):
             self.store.max_lengths_tensor[request.py_seq_slot].fill_(
                 min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
             )
+            self.store.end_ids[request.py_seq_slot].fill_(
+                request.py_end_id if request.py_end_id is not None else -1
+            )
 
     def _prepare_beam_search(
         self,
@@ -2822,7 +2830,7 @@ def _write_finish_reasons(
             batched_finish_reasons,
         )
         batched_finish_reasons = torch.where(
-            self._are_end_id(requests, tokens),
+            self._are_end_id(self.store.end_ids[seq_slots], tokens),
             self._reason_tensors[FinishReason.END_ID],
             batched_finish_reasons,
         )
@@ -2838,21 +2846,8 @@
         )
         first_finish_reasons[seq_slots] = batched_first_finish_reasons
 
-    def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
-        end_ids_tensor = (
-            torch.tensor(
-                [
-                    ([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
-                    for req in requests
-                ]
-                * self.max_tokens,
-                pin_memory=True,
-                dtype=tokens.dtype,
-            )
-            .view(self.max_tokens, len(requests), self.max_beam_width)
-            .to(device="cuda", non_blocking=True)
-        )
-        return tokens == end_ids_tensor
+    def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
+        return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
 
     def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
         """Checks which sequences are at or beyond the max length
