@@ -960,15 +960,15 @@ def _create_store(self) -> Store:
960960 finish_reasons = int_tensor (self .NEW_TOKENS_SHAPE )
961961
962962 # Only used for logprobs processing or beam search
963- sampled_log_probs = torch .zeros (self .LOGPROBS_SHAPE , device = "cuda" , dtype = torch .float32 )
963+ sampled_log_probs = torch .empty (self .LOGPROBS_SHAPE , device = "cuda" , dtype = torch .float32 )
964964 # Only used for logprobs processing
965- sampled_log_prob_indices = torch .zeros (
965+ sampled_log_prob_indices = torch .empty (
966966 self .LOGPROBS_SHAPE , device = "cuda" , dtype = torch .int32
967967 )
968- sampled_log_prob_ranks = torch .zeros (self .LOGPROBS_SHAPE , device = "cuda" , dtype = torch .int32 )
968+ sampled_log_prob_ranks = torch .empty (self .LOGPROBS_SHAPE , device = "cuda" , dtype = torch .int32 )
969969 # These are 0-sized tensors if topk-logprobs are not used
970- topk_indices = torch .zeros (self .topk_logprobs_shape , device = "cuda" , dtype = torch .int32 )
971- topk_vals = torch .zeros (self .topk_logprobs_shape , device = "cuda" , dtype = torch .float32 )
970+ topk_indices = torch .empty (self .topk_logprobs_shape , device = "cuda" , dtype = torch .int32 )
971+ topk_vals = torch .empty (self .topk_logprobs_shape , device = "cuda" , dtype = torch .float32 )
972972
973973 # Only used for beam search
974974 cache_indirection : torch .Tensor | None = None
@@ -978,11 +978,11 @@ def _create_store(self) -> Store:
978978 original_tokens : torch .Tensor | None = None
979979 first_finish_reasons : torch .Tensor | None = None
980980 if self ._use_beam_search :
981- cache_indirection = torch .zeros (
981+ cache_indirection = torch .empty (
982982 self .CACHE_INDIRECTION_SHAPE , device = "cuda" , dtype = torch .int
983983 )
984984 cache_indirection_buffer = int_tensor (self .CACHE_INDIRECTION_SHAPE )
985- cum_log_probs = torch .zeros (
985+ cum_log_probs = torch .empty (
986986 self .CACHE_INDIRECTION_SHAPE [:- 1 ], device = "cuda" , dtype = torch .float32
987987 )
988988 predecessor_beams = int_tensor (self .CACHE_INDIRECTION_SHAPE [:- 1 ])
0 commit comments