@@ -198,27 +198,63 @@ def inner(
     seqno: int, outputs: dict[str, Tensor], attention_mask: Tensor
 ) -> dict[str, Tensor]:
     next_token_logits = outputs["logits"][:, -1, :]
+    batch_size = next_token_logits.size(0)
     device = next_token_logits.device
 
-    if self.do_sample:
-        # Apply temperature
-        next_token_logits = next_token_logits / self.temperature
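+    # Fetch or lazily create per-sequence decoding state: one token buffer
+    # and one finished flag per batch row, keyed by seqno.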
+    state = self.generated_tokens.get(seqno)
+    if state is None:
+        state = {
+            "tokens": [[] for _ in range(batch_size)],
+            "finished": [False] * batch_size,
+        }
+        self.generated_tokens[seqno] = state
 
-        # Apply top_p (nucleus) sampling
-        sorted_logits, sorted_indices = torch.sort(
-            next_token_logits, descending=True
-        )
-        cumulative_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+    gen_tokens = state["tokens"]
+    finished = state["finished"]
+
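+    # Guard against a seqno being reused with a different batch size.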
+    if len(gen_tokens) != batch_size:
+        raise ValueError(
+            "Mismatched batch size between cached tokens and new logits"
         )
 
-        sorted_indices_to_remove = cumulative_probs > self.top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-            ..., :-1
-        ].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        next_token_logits[:, indices_to_remove] = float("-inf")
+    if self.do_sample:
+        if self.temperature > 0.0:
+            # Apply temperature
+            next_token_logits = next_token_logits / self.temperature
+        else:
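+            # A temperature of 0 would make the division above ill-defined;
+            # fall back to greedy decoding for this and subsequent steps.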
+            self.do_sample = False
+            logger.debug("temperature is 0.0, switching to greedy decoding")
+
+    eos_token_id = self.eos_token_id
+
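+    # Rows that already finished keep emitting EOS: mask every logit to -inf
+    # except EOS, whose logit of 0.0 becomes probability 1 after softmax.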
+    for idx, is_finished in enumerate(finished):
+        if is_finished:
+            next_token_logits[idx] = float("-inf")
+            next_token_logits[idx, eos_token_id] = 0.0
+
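+    # Nucleus (top-p) filtering: keep the smallest set of tokens whose
+    # cumulative probability exceeds top_p and drop the rest.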
+    if self.do_sample:
+        if 0.0 < self.top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(
+                next_token_logits, descending=True
+            )
+            cumulative_probs = torch.cumsum(
+                torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+            )
+            sorted_indices_to_remove = cumulative_probs > self.top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                ..., :-1
+            ].clone()
+            sorted_indices_to_remove[..., 0] = False
+
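+            # Map the removal mask back from sorted order to vocabulary
+            # order, skipping finished rows so their EOS-only logits survive.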
+            for batch_idx in range(batch_size):
+                if finished[batch_idx]:
+                    continue
+                indices_to_remove = sorted_indices[batch_idx][
+                    sorted_indices_to_remove[batch_idx]
+                ]
+                next_token_logits[batch_idx, indices_to_remove] = float(
+                    "-inf"
+                )
 
         # Sample from the filtered distribution
         next_token_probs = torch.nn.functional.softmax(
@@ -229,20 +265,31 @@ def inner(
         # Greedy decoding
         next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    if seqno not in self.generated_tokens:
-        self.generated_tokens[seqno] = [[] for _ in range(len(next_token))]
-    gen_tokens = self.generated_tokens[seqno]
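+    # Append the new token per row; rows that finished earlier are pinned to
+    # EOS so downstream consumers see a consistent stream.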
+    for i in range(batch_size):
+        if finished[i]:
+            next_token[i, 0] = eos_token_id
+            continue
+
+        token_id = next_token[i, 0].item()
+        gen_tokens[i].append(token_id)
+
+        if token_id == eos_token_id or len(gen_tokens[i]) >= self.max_new_tokens:
+            finished[i] = True
+            if token_id != eos_token_id:
+                next_token[i, 0] = eos_token_id
+
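+    # Once every row is finished, right-pad the shorter rows with EOS so the
+    # result packs into a rectangular tensor.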
+    if all(finished):
+        max_length = max(len(tokens) for tokens in gen_tokens)
+        padded_tokens = []
+        for tokens in gen_tokens:
+            if len(tokens) < max_length:
+                tokens = tokens + [eos_token_id] * (max_length - len(tokens))
+            padded_tokens.append(tokens)
 
-    for i, token in enumerate(next_token):
-        gen_tokens[i].append(token.item())
+        tensor = torch.tensor(padded_tokens, dtype=torch.int64, device=device)
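+        # Preserve the previous single-sequence output shape: return a 1-D
+        # tensor when the batch holds exactly one sequence.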
+        if tensor.size(0) == 1:
+            tensor = tensor[0]
 
-    # Check for EOS token or if max number of tokens are generated
-    if (
-        self.max_new_tokens == len(gen_tokens[0])
-        or next_token[0].item() == self.eos_token_id
-    ):
-        gen_tokens = gen_tokens if len(gen_tokens) > 1 else gen_tokens[0]
-        tensor = torch.tensor(gen_tokens, dtype=torch.int64, device=device)
         del self.generated_tokens[seqno]
 
     return {"tokens": tensor}