Skip to content

Commit 2c6e567

Browse files
committed
Fix return_top_n negative infinity bug
1 parent 642041d commit 2c6e567

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

server/text_generation_server/utils/tokens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def get_token_info(
396396
top_n = min(return_top_n, flat_scores.size(-1))
397397
# Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
398398
nth_highest = flat_scores.topk(top_n).values[-1]
399-
torch.nan_to_num_(nth_highest, neginf=torch.finfo(torch.float).min)
399+
torch.nan_to_num_(nth_highest, neginf=torch.finfo(flat_scores.dtype).min)
400400
# Get indices (token ids) of all scores >= nth highest value,
401401
# cap length at 4 * top_n as a precaution
402402
top_n_indices = (flat_scores >= nth_highest).nonzero().squeeze(-1)[:(top_n * 4)]

0 commit comments

Comments
 (0)