@@ -202,12 +202,13 @@ def __init__(
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

             if any(x < 1.0 for x in top_p):
+                #assert all(x != 0 for x in top_p)
                 warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
-            # We specifically exclude degenerate case 0, we devolves into greedy decoding,
-            # to align with the nonvectorized typical logit warping behavior.
-            if any(0.0 < x < 1.0 for x in typical_p):
-                corrected_probs = [p if p != 0 else 1 for p in typical_p]
-                warpers.append(HeterogeneousTypicalLogitsWarper(corrected_probs, dtype, device))
+
+            if any(x < 1.0 for x in typical_p):
+                #assert all(x != 0 for x in typical_p)
+                warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
+
             self.choice = HeterogeneousSampling(do_sample, seeds, device)
         else:
             self.choice = Greedy()
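
A note on the degenerate case the removed comment referred to: nucleus (top-p) filtering keeps the smallest set of tokens whose cumulative probability reaches top_p, so a value of 0 keeps only the single most probable token and decoding degenerates to greedy. A minimal single-sequence sketch of that behavior (the standard masking pattern, not the internals of HeterogeneousTopPLogitsWarper, which aren't shown in this diff):

import torch

def top_p_mask(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort descending, then mark every token past the prefix whose
    # cumulative probability mass exceeds top_p
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is still kept;
    # with top_p == 0 everything except the argmax is removed, i.e. greedy decoding
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    return logits.masked_fill(remove.scatter(-1, sorted_indices, remove), -float('inf'))

That collapse to greedy is why the commented-out asserts exclude 0, and why from_pb below maps the proto default of 0 to 1.0 (disabled) instead.
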
@@ -279,8 +280,10 @@ def from_pb(
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty if pb_.HasField('repetition_penalty') else 1.0 for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
-            top_p=[pb_.top_p for pb_ in pb],
-            typical_p=[pb_.typical_p for pb_ in pb],
+            # Ensure that default (zero) values for top_p and typical_p are converted to 1.0
+            # (which corresponds to disabled in both cases)
+            top_p=[pb_.top_p if pb_.top_p > 0 else 1.0 for pb_ in pb],
+            typical_p=[pb_.typical_p if pb_.typical_p > 0 else 1.0 for pb_ in pb],
             length_penalty=[
                 (pb_.length_penalty.start_index, pb_.length_penalty.decay_factor)
                 if pb_.HasField('length_penalty') else None for pb_ in pb
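
The zero-to-1.0 normalization above is needed because proto3 scalar fields carry no presence information: an unset float deserializes as 0.0, indistinguishable from an explicit 0.0, and HasField only works for message fields (as it is used for repetition_penalty and length_penalty here). A small sketch of the mapping, using a hypothetical normalize helper:

def normalize(value: float, disabled: float = 1.0) -> float:
    # proto3 can't distinguish "unset" from "explicitly 0", so treat 0 as unset
    # and substitute the value that disables the warper (1.0 for top_p/typical_p)
    return value if value > 0 else disabled

assert normalize(0.0) == 1.0   # unset field -> warper disabled
assert normalize(0.9) == 0.9   # explicit value passes through
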
@@ -310,7 +313,7 @@ def filter(self, indices):
         self.current_tokens = [self.current_tokens[i] for i in indices]
         self.min_new_tokens = [self.min_new_tokens[i] for i in indices]
         self.length_penalty = [self.length_penalty[i] for i in indices]
-        self.return_logprobs = [self.return_logprobs[i] for i in indices]
+        self.return_logprobs = [self.return_logprobs[i] for i in indices]

         if any(self.do_sample):
             self.choice.filter(indices)
@@ -370,12 +373,13 @@ def filter(self, indices):
         self.samplings = [self.samplings[i] for i in indices]
         return self

+
 # Extract requested token information from model output
 def get_token_info(
     request: generate_pb2.Request,
     scores: torch.Tensor,  # Assumes shape is [1, vocab_size]
     next_token: torch.Tensor,
-    logprobs: Optional[torch.Tensor],  # Assumes shape matches logits
+    logprobs: Optional[torch.Tensor],  # Assumes shape matches logits
 ) -> TokenInfo:
     next_token = next_token.item()
     token_info = TokenInfo(request_id=request.id, token_id=next_token)
@@ -391,9 +395,8 @@ def get_token_info(
         # Ensure top_n doesn't exceed vocab size
         top_n = min(return_top_n, flat_scores.size(-1))
         # Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
-        nth_highest = torch.topk(flat_scores, top_n)[0][-1]
-        if nth_highest == -float('inf'):
-            nth_highest = torch.finfo(flat_scores.dtype).min
+        nth_highest = flat_scores.topk(top_n).values[-1]
+        torch.nan_to_num_(nth_highest, neginf=torch.finfo(torch.float).min)
         # Get indices (token ids) of all scores >= nth highest value,
         # cap length at 4 * top_n as a precaution
         top_n_indices = (flat_scores >= nth_highest).nonzero().squeeze(-1)[:(top_n * 4)]
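
The new version folds the explicit -inf check into an in-place torch.nan_to_num_, which rewrites a neginf value in the 0-dim result tensor. A small sketch with illustrative values:

import torch

# Suppose an earlier top-k warper left only two finite scores but top_n is 3
flat_scores = torch.tensor([2.0, 1.0, float('-inf'), float('-inf')])
nth_highest = flat_scores.topk(3).values[-1]                     # tensor(-inf)
torch.nan_to_num_(nth_highest, neginf=torch.finfo(torch.float).min)
# nth_highest is now the smallest finite float32; without this fix-up,
# flat_scores >= nth_highest would be True even for the masked-out -inf entries
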
@@ -407,7 +410,7 @@ def get_token_info(
     # Token ranks if requested
     if request.details.ranks:
         #TODO if we're also returning top_n perhaps search those first
-        token_info.rank = (scores > scores[0][next_token]).sum() + 1
+        token_info.rank = (scores > scores[0, next_token]).sum() + 1

     return token_info

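On the indexing change: scores[0][next_token] and scores[0, next_token] select the same element, the tuple form just skips the intermediate row view, so this is presumably a micro-cleanup rather than a behavior change. The rank itself is 1-based, counting how many scores strictly exceed the chosen token's:

import torch

scores = torch.tensor([[0.1, 0.9, 0.4]])      # [1, vocab_size]
next_token = 2
rank = (scores > scores[0, next_token]).sum() + 1
assert rank.item() == 2                        # only 0.9 beats 0.4
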
@@ -447,7 +450,7 @@ def get_input_tokens_info(request, input_token_ids, all_input_logits) -> InputTo
     # Ensure top_n doesn't exceed vocab size
     top_n = min(top_n, all_input_logits.size(-1))
     # Get the nth highest value for each input token's set of logits
-    nth_highest_values = torch.topk(all_input_logits, top_n)[0][..., -1, None]
+    nth_highest_values = torch.topk(all_input_logits, top_n).values[..., -1, None]
     # Construct bool tensor marking all scores >= nth highest value for each token
     diff = (all_input_logits >= nth_highest_values)
     # Gather set of marked indices for each token (correspond to top token ids)
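
Same .values cleanup as in get_token_info, applied per input token here: topk over the last dim yields a [num_input_tokens, top_n] tensor, and [..., -1, None] keeps each row's nth-highest value with a trailing unit dimension so it broadcasts against the full logits. A quick shape sketch with assumed sizes:

import torch

all_input_logits = torch.randn(5, 32)         # [num_input_tokens, vocab_size]
top_n = 3
nth_highest_values = torch.topk(all_input_logits, top_n).values[..., -1, None]   # [5, 1]
diff = all_input_logits >= nth_highest_values                  # each token's top-3 marked
assert diff.sum(dim=-1).eq(top_n).all()                        # exactly top_n per row (barring ties)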