@@ -980,8 +980,8 @@ def _create_store(self) -> Store:
         # Tensors necessary for all sampling methods
         new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
         finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
-        max_lengths_tensor = int_tensor(self.max_num_sequences)
-        end_ids = int_tensor(self.max_num_sequences)
+        max_lengths_tensor = int_tensor(self.max_num_sequences)
+        end_ids = int_tensor(self.max_num_sequences)

         # Only used for logprobs processing or beam search
         sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
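
The `int_tensor` helper is defined elsewhere in the file and not shown in this hunk. A minimal sketch of what it is assumed to do, mirroring the explicit `torch.empty` call used for `sampled_log_probs` just above (the real helper may take a dtype argument or use a different allocator):

import torch

# Assumed behaviour of the int_tensor helper used above (not part of this diff).
def int_tensor(shape) -> torch.Tensor:
    # Uninitialized CUDA buffer with integer dtype, filled later by the sampling kernels.
    return torch.empty(shape, device="cuda", dtype=torch.int32)
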
@@ -1082,6 +1082,9 @@ def __init__(self, args: Args):
                 FinishReason.CANCELLED,
             ]  # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
         }
+        self._max_tokens_offset = torch.arange(
+            1, self.max_tokens + 1, device="cuda", dtype=torch.int32
+        ).view(-1, 1, 1)

         self._grouped_sampler_cls: Type[GroupedStrategySampler]
         if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
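
Why this cache helps: `_are_max_length` below previously rebuilt a `torch.arange` on every decoding step; storing the 1..max_tokens offset once in `__init__` lets the hot path do a single broadcasted add. A small CPU-only sketch of the broadcasting, with illustrative sizes (the real tensors live on CUDA and `num_seqs`/`max_beam_width` mirror the engine configuration):

import torch

max_tokens, num_seqs, max_beam_width = 4, 3, 2          # illustrative sizes
# Built once, shaped so the future-step index broadcasts along dim 0.
max_tokens_offset = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)

seq_lens = torch.tensor([5, 7, 9], dtype=torch.int32)   # current length per request
# (1, num_seqs, 1) + (max_tokens, 1, 1) -> (max_tokens, num_seqs, 1),
# then expand the beam dimension without copying data.
lengths = (seq_lens.view(1, -1, 1) + max_tokens_offset).expand(max_tokens, -1, max_beam_width)
assert lengths.shape == (max_tokens, num_seqs, max_beam_width)
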
@@ -2864,12 +2867,9 @@ def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) ->
             A tensor of shape (max_tokens, len(requests), max_beam_width)
             where each element is True if the sequence is at or beyond the max length, False otherwise
         """
-        lengths_tensor = (
-            seq_lens.view(1, -1, 1)
-            + torch.arange(
-                1, self.max_tokens + 1, device=seq_lens.device, dtype=seq_lens.dtype
-            ).view(-1, 1, 1)
-        ).expand(self.max_tokens, -1, self.max_beam_width)
+        lengths_tensor = (seq_lens.view(1, -1, 1) + self._max_tokens_offset).expand(
+            self.max_tokens, -1, self.max_beam_width
+        )
         max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
             self.max_tokens, -1, self.max_beam_width
         )
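
For context, a hedged sketch of how the two expanded tensors are assumed to be combined into the boolean mask described in the docstring; the comparison itself sits outside this hunk and the real method may also account for already-finished sequences:

# Hypothetical continuation of _are_max_length (not shown in this diff).
# True wherever a sequence would reach or exceed its per-request limit at that
# future step; shape (max_tokens, len(requests), max_beam_width).
are_max_length = lengths_tensor >= max_lengths_tensor
return are_max_length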