We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 642041d, commit 2c6e567 (Copy full SHA for 2c6e567)
server/text_generation_server/utils/tokens.py
@@ -396,7 +396,7 @@ def get_token_info(
396
top_n = min(return_top_n, flat_scores.size(-1))
397
# Get nth highest value, ensure it's not -inf (for example if top_n > top_k)
398
nth_highest = flat_scores.topk(top_n).values[-1]
399
- torch.nan_to_num_(nth_highest, neginf=torch.finfo(torch.float).min)
+ torch.nan_to_num_(nth_highest, neginf=torch.finfo(flat_scores.dtype).min)
400
# Get indices (token ids) of all scores >= nth highest value,
401
# cap length at 4 * top_n as a precaution
402
top_n_indices = (flat_scores >= nth_highest).nonzero().squeeze(-1)[:(top_n * 4)]
0 commit comments