File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed
src/transformers/generation Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -3727,9 +3727,17 @@ def _get_top_k_continuations(
3727
3727
3728
3728
# Gather the top K scores from _all_ beams.
3729
3729
if do_sample :
3730
- topk_indices = torch .multinomial (
3731
- nn .functional .softmax (accumulated_log_probs , dim = - 1 ), num_samples = beams_to_keep
3732
- )
3730
+ # Handle potential NaN values in accumulated_log_probs
3731
+ probs = nn .functional .softmax (accumulated_log_probs , dim = - 1 )
3732
+ # Replace NaN values with uniform distribution
3733
+ if torch .isnan (probs ).any ():
3734
+ # Create a mask for NaN positions
3735
+ nan_mask = torch .isnan (probs )
3736
+ # Replace NaN with a small uniform probability
3737
+ probs = torch .where (nan_mask , torch .ones_like (probs ) / probs .shape [- 1 ], probs )
3738
+ # Renormalize to ensure probabilities sum to 1
3739
+ probs = probs / probs .sum (dim = - 1 , keepdim = True )
3740
+ topk_indices = torch .multinomial (probs , num_samples = beams_to_keep )
3733
3741
topk_log_probs = torch .gather (input = accumulated_log_probs , dim = 1 , index = topk_indices )
3734
3742
else :
3735
3743
topk_log_probs , topk_indices = torch .topk (accumulated_log_probs , k = beams_to_keep )
You can’t perform that action at this time.
0 commit comments