Skip to content

Commit c7d6f4f

Browse files
Avoid unnecessary initialization for runtime buffers
Signed-off-by: Yuan Tong <[email protected]>
1 parent 348a547 commit c7d6f4f

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -960,15 +960,15 @@ def _create_store(self) -> Store:
960960
finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
961961

962962
# Only used for logprobs processing or beam search
963-
sampled_log_probs = torch.zeros(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
963+
sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
964964
# Only used for logprobs processing
965-
sampled_log_prob_indices = torch.zeros(
965+
sampled_log_prob_indices = torch.empty(
966966
self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32
967967
)
968-
sampled_log_prob_ranks = torch.zeros(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32)
968+
sampled_log_prob_ranks = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32)
969969
# These are 0 sized tensors, if topk-logprobs are not used
970-
topk_indices = torch.zeros(self.topk_logprobs_shape, device="cuda", dtype=torch.int32)
971-
topk_vals = torch.zeros(self.topk_logprobs_shape, device="cuda", dtype=torch.float32)
970+
topk_indices = torch.empty(self.topk_logprobs_shape, device="cuda", dtype=torch.int32)
971+
topk_vals = torch.empty(self.topk_logprobs_shape, device="cuda", dtype=torch.float32)
972972

973973
# Only used for beam search
974974
cache_indirection: torch.Tensor | None = None
@@ -978,11 +978,11 @@ def _create_store(self) -> Store:
978978
original_tokens: torch.Tensor | None = None
979979
first_finish_reasons: torch.Tensor | None = None
980980
if self._use_beam_search:
981-
cache_indirection = torch.zeros(
981+
cache_indirection = torch.empty(
982982
self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int
983983
)
984984
cache_indirection_buffer = int_tensor(self.CACHE_INDIRECTION_SHAPE)
985-
cum_log_probs = torch.zeros(
985+
cum_log_probs = torch.empty(
986986
self.CACHE_INDIRECTION_SHAPE[:-1], device="cuda", dtype=torch.float32
987987
)
988988
predecessor_beams = int_tensor(self.CACHE_INDIRECTION_SHAPE[:-1])

0 commit comments

Comments (0)