Commit ba62872

Precalculate offset tensor only once and reuse it in _are_max_length
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent cda092d commit ba62872

File tree

1 file changed: +8 -8

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 8 additions & 8 deletions
@@ -980,8 +980,8 @@ def _create_store(self) -> Store:
         # Tensors necessary for all sampling methods
         new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
         finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
-        max_lengths_tensor=int_tensor(self.max_num_sequences)
-        end_ids=int_tensor(self.max_num_sequences)
+        max_lengths_tensor = int_tensor(self.max_num_sequences)
+        end_ids = int_tensor(self.max_num_sequences)

         # Only used for logprobs processing or beam search
         sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@@ -1082,6 +1082,9 @@ def __init__(self, args: Args):
                 FinishReason.CANCELLED,
             ]  # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
         }
+        self._max_tokens_offset = torch.arange(
+            1, self.max_tokens + 1, device="cuda", dtype=torch.int32
+        ).view(-1, 1, 1)

         self._grouped_sampler_cls: Type[GroupedStrategySampler]
         if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
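The buffer added in `__init__` has shape (max_tokens, 1, 1) and lives on the GPU for the sampler's lifetime, so `_are_max_length` no longer allocates a fresh `torch.arange` on every call. A minimal sketch of the pattern, using an illustrative class name and CPU tensors so the snippet runs without a GPU (none of these names come from the diff):

import torch

class OffsetDemo:
    # Illustrative stand-in for the sampler class: the offset tensor is
    # built once at construction time instead of on every lookup call.
    def __init__(self, max_tokens: int, device: str = "cpu") -> None:
        self.max_tokens = max_tokens
        # Offsets 1..max_tokens, shaped (max_tokens, 1, 1) so they broadcast
        # against per-sequence lengths shaped (1, num_seqs, 1).
        self._max_tokens_offset = torch.arange(
            1, max_tokens + 1, device=device, dtype=torch.int32
        ).view(-1, 1, 1)

    def candidate_lengths(self, seq_lens: torch.Tensor) -> torch.Tensor:
        # Reuse the cached offset; no per-call torch.arange allocation.
        return seq_lens.view(1, -1, 1) + self._max_tokens_offset

demo = OffsetDemo(max_tokens=3)
print(demo.candidate_lengths(torch.tensor([5, 9], dtype=torch.int32)).shape)
# -> torch.Size([3, 2, 1])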
@@ -2864,12 +2867,9 @@ def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) ->
         A tensor of shape (max_tokens, len(requests), max_beam_width)
         where each element is True if the sequence is at or beyond the max length, False otherwise
         """
-        lengths_tensor = (
-            seq_lens.view(1, -1, 1)
-            + torch.arange(
-                1, self.max_tokens + 1, device=seq_lens.device, dtype=seq_lens.dtype
-            ).view(-1, 1, 1)
-        ).expand(self.max_tokens, -1, self.max_beam_width)
+        lengths_tensor = (seq_lens.view(1, -1, 1) + self._max_tokens_offset).expand(
+            self.max_tokens, -1, self.max_beam_width
+        )
         max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
             self.max_tokens, -1, self.max_beam_width
         )
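For intuition on why the rewritten `_are_max_length` body is equivalent: adding shapes (1, num_seqs, 1) + (max_tokens, 1, 1) broadcasts to (max_tokens, num_seqs, 1), and `.expand` then views that as (max_tokens, num_seqs, max_beam_width) without copying. A quick equivalence check with made-up sizes (CPU tensors, illustrative values only):

import torch

max_tokens, max_beam_width = 4, 2
seq_lens = torch.tensor([10, 12, 7], dtype=torch.int32)

# Old form: rebuild the offset on every call.
old = (
    seq_lens.view(1, -1, 1)
    + torch.arange(1, max_tokens + 1, dtype=seq_lens.dtype).view(-1, 1, 1)
).expand(max_tokens, -1, max_beam_width)

# New form: reuse a precomputed offset, then expand.
offset = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)
new = (seq_lens.view(1, -1, 1) + offset).expand(max_tokens, -1, max_beam_width)

assert torch.equal(old, new)
print(new.shape)  # torch.Size([4, 3, 2])

One nuance visible in the diff itself: the old code matched `seq_lens.device` and `seq_lens.dtype` per call, whereas the cached `_max_tokens_offset` is pinned to int32 on CUDA, so the change assumes `seq_lens` arrives as an int32 CUDA tensor.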
