Skip to content

Commit 7e230e5

Browse files
committed
fix
1 parent 9ca43c6 commit 7e230e5

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/transformers/generation/utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3727,9 +3727,17 @@ def _get_top_k_continuations(
37273727

37283728
# Gather the top K scores from _all_ beams.
37293729
if do_sample:
3730-
topk_indices = torch.multinomial(
3731-
nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
3732-
)
3730+
# Handle potential NaN values in accumulated_log_probs
3731+
probs = nn.functional.softmax(accumulated_log_probs, dim=-1)
3732+
# Replace NaN values with uniform distribution
3733+
if torch.isnan(probs).any():
3734+
# Create a mask for NaN positions
3735+
nan_mask = torch.isnan(probs)
3736+
# Replace NaN with a small uniform probability
3737+
probs = torch.where(nan_mask, torch.ones_like(probs) / probs.shape[-1], probs)
3738+
# Renormalize to ensure probabilities sum to 1
3739+
probs = probs / probs.sum(dim=-1, keepdim=True)
3740+
topk_indices = torch.multinomial(probs, num_samples=beams_to_keep)
37333741
topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
37343742
else:
37353743
topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)

0 commit comments

Comments
 (0)