Skip to content

Commit 1fef88e

Browse files
authored
[None][chore] Improve sampler performance by replacing torch.where with masked_fill_ (#11949)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent 81350b7 commit 1fef88e

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,7 +1792,8 @@ def _write_finish_reasons(
17921792
if not single_token_stop_words_only
17931793
else self._are_stop_words_single_token
17941794
)
1795-
batched_finish_reasons[:, stop_word_indices] = torch.where(
1795+
batched_finish_reasons_stop_words = batched_finish_reasons[:, stop_word_indices]
1796+
_ = batched_finish_reasons_stop_words.masked_fill_(
17961797
stop_words_func(
17971798
stop_seq_slots,
17981799
stop_tokens,
@@ -1801,18 +1802,17 @@ def _write_finish_reasons(
18011802
else num_accepted_tokens,
18021803
),
18031804
FinishReason.STOP_WORDS.value,
1804-
batched_finish_reasons[:, stop_word_indices],
18051805
)
1806+
batched_finish_reasons[:, stop_word_indices] = batched_finish_reasons_stop_words
18061807

1807-
batched_finish_reasons = torch.where(
1808+
_ = batched_finish_reasons.masked_fill_(
18081809
self._are_max_length(seq_lens, store.max_lengths_cuda[seq_slots]),
18091810
FinishReason.LENGTH.value,
1810-
batched_finish_reasons,
18111811
)
1812-
batched_finish_reasons = torch.where(
1812+
1813+
_ = batched_finish_reasons.masked_fill_(
18131814
self._are_end_id(store.end_ids_cuda[seq_slots], tokens),
18141815
FinishReason.END_ID.value,
1815-
batched_finish_reasons,
18161816
)
18171817

18181818
finish_reasons[:, seq_slots] = batched_finish_reasons
@@ -1916,7 +1916,7 @@ def _are_stop_words(
19161916
# Fill in the new tokens at the end of the past tokens buffer
19171917
full_tokens[-self._max_tokens :] = tokens
19181918
# short words are padded with _PAD_STOP_WORD_TOKEN_ID, so we need to mask them
1919-
mask = stop_words != self._PAD_STOP_WORD_TOKEN_ID
1919+
mask = stop_words == self._PAD_STOP_WORD_TOKEN_ID
19201920
matches = torch.empty(
19211921
(
19221922
self._max_tokens,
@@ -1941,15 +1941,15 @@ def _are_stop_words(
19411941
stop_words_for_match = stop_words.unsqueeze(0)
19421942
_ = torch.eq(full_tokens_for_match, stop_words_for_match, out=matches)
19431943
# Mask the padding tokens
1944-
matches_after_mask = torch.where(
1945-
mask.unsqueeze(0).expand(self._max_tokens, -1, -1, -1, -1), matches, True
1944+
_ = matches.masked_fill_(
1945+
mask.unsqueeze(0).expand(self._max_tokens, -1, -1, -1, -1), True
19461946
)
19471947
# Update the past tokens storage for the next iteration
19481948
store.past_tokens_cuda[:, seq_slots] = full_tokens
19491949
# Return the result
19501950
word_len_dim = 2
19511951
num_words_dim = 1
1952-
return torch.any(matches_after_mask.all(dim=word_len_dim), dim=num_words_dim)
1952+
return torch.any(matches.all(dim=word_len_dim), dim=num_words_dim)
19531953

19541954
@nvtx_range("_are_stop_words_single_token")
19551955
def _are_stop_words_single_token(
@@ -3721,8 +3721,10 @@ def _sample_batched_by_strategy(
37213721
group_logits_indices_for_processed_logprobs_cuda
37223722
]
37233723
current_softmax_cuda = group_softmax_cuda[logit_indices_for_processed_logprobs_cuda]
3724-
processed_logits_cuda = torch.where(
3725-
current_softmax_cuda > 0, current_logits_cuda, float("-inf")
3724+
3725+
# processed_logits_cuda is an alias to current_logits_cuda after this operation
3726+
processed_logits_cuda = current_logits_cuda.masked_fill_(
3727+
current_softmax_cuda == 0, float("-inf")
37263728
)
37273729
temperature_for_processed_logprobs = group_temperature_cuda
37283730
if isinstance(temperature_for_processed_logprobs, torch.Tensor):

0 commit comments

Comments (0)