Skip to content

Commit 25e974e

Browse files
committed
remove
1 parent 7e230e5 commit 25e974e

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

src/transformers/generation/utils.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3727,17 +3727,9 @@ def _get_top_k_continuations(

         # Gather the top K scores from _all_ beams.
         if do_sample:
-            # Handle potential NaN values in accumulated_log_probs
-            probs = nn.functional.softmax(accumulated_log_probs, dim=-1)
-            # Replace NaN values with uniform distribution
-            if torch.isnan(probs).any():
-                # Create a mask for NaN positions
-                nan_mask = torch.isnan(probs)
-                # Replace NaN with a small uniform probability
-                probs = torch.where(nan_mask, torch.ones_like(probs) / probs.shape[-1], probs)
-                # Renormalize to ensure probabilities sum to 1
-                probs = probs / probs.sum(dim=-1, keepdim=True)
-            topk_indices = torch.multinomial(probs, num_samples=beams_to_keep)
+            topk_indices = torch.multinomial(
+                nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
+            )
             topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
         else:
             topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)

0 commit comments

Comments
 (0)