@@ -932,6 +932,9 @@ class Store:
     max_lengths_tensor: torch.Tensor
     """Shape: batch_size
     Usage: Stores the maximum lengths for each request"""
+    end_ids: torch.Tensor
+    """Shape: batch_size
+    Usage: Stores the end ids for each request"""
     finish_reasons: torch.Tensor
     """Shape: max_tokens, batch_size, beam_width
     Usage: Stores the currently estimated finish_reasons for each request"""
@@ -977,7 +980,8 @@ def _create_store(self) -> Store:
         # Tensors necessary for all sampling methods
         new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
         finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
-        max_lengths_tensor = int_tensor(self.max_num_sequences),
+        max_lengths_tensor = int_tensor(self.max_num_sequences)
+        end_ids = int_tensor(self.max_num_sequences)

         # Only used for logprobs processing or beam search
         sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@@ -1012,6 +1016,7 @@ def _create_store(self) -> Store:
             new_tokens=new_tokens,
             finish_reasons=finish_reasons,
             max_lengths_tensor=max_lengths_tensor,
+            end_ids=end_ids,
             cache_indirection=cache_indirection,
             cache_indirection_buffer=cache_indirection_buffer,
             cum_log_probs=cum_log_probs,
@@ -1549,6 +1554,9 @@ def setup_sampler_step(self, requests: ScheduledRequests):
             self.store.max_lengths_tensor[request.py_seq_slot].fill_(
                 min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
             )
+            self.store.end_ids[request.py_seq_slot].fill_(
+                request.py_end_id if request.py_end_id is not None else -1
+            )

     def _prepare_beam_search(
         self,
@@ -2822,7 +2830,7 @@ def _write_finish_reasons(
             batched_finish_reasons,
         )
         batched_finish_reasons = torch.where(
-            self._are_end_id(requests, tokens),
+            self._are_end_id(self.store.end_ids[seq_slots], tokens),
             self._reason_tensors[FinishReason.END_ID],
             batched_finish_reasons,
         )
@@ -2838,21 +2846,8 @@ def _write_finish_reasons(
         )
         first_finish_reasons[seq_slots] = batched_first_finish_reasons

-    def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
-        end_ids_tensor = (
-            torch.tensor(
-                [
-                    ([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
-                    for req in requests
-                ]
-                * self.max_tokens,
-                pin_memory=True,
-                dtype=tokens.dtype,
-            )
-            .view(self.max_tokens, len(requests), self.max_beam_width)
-            .to(device="cuda", non_blocking=True)
-        )
-        return tokens == end_ids_tensor
+    def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
+        return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)

     def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
         """Checks which sequences are at or beyond the max length
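A note on the rewritten _are_end_id, for readers skimming the diff: the removed version rebuilt a pinned host tensor of end ids on every call and copied it to the GPU, while the new version compares against the persistent store.end_ids tensor (filled once per request in setup_sampler_step) via a broadcasted view. The standalone sketch below, using made-up shapes and values rather than anything from the PR, illustrates that broadcast.

import torch

# Illustrative shapes only; not the values used in the PR.
max_tokens, batch_size, beam_width = 2, 3, 2

# Per-slot end ids, shape [batch_size]; -1 marks "no end id", as in setup_sampler_step.
end_ids = torch.tensor([7, -1, 5])

# Sampled tokens, shape [max_tokens, batch_size, beam_width].
tokens = torch.tensor(
    [
        [[7, 1], [2, 3], [5, 5]],
        [[4, 7], [6, 6], [0, 5]],
    ]
)

# view(1, -1, 1) followed by expand yields a [max_tokens, batch_size, beam_width] view
# of end_ids without copying, so the elementwise == lines up with tokens position by position.
mask = tokens == end_ids.view(1, -1, 1).expand(max_tokens, -1, beam_width)
print(mask)
# The middle request (end_id -1) stays False everywhere, since no real token id is -1.

Because store.end_ids is presumably allocated on the GPU alongside the other store tensors built by int_tensor, the comparison also avoids the per-step host allocation, pinning, and host-to-device copy that the removed implementation paid.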