src/transformers/generation: 1 file changed (+3, -11 lines)

@@ -3727,17 +3727,9 @@ def _get_top_k_continuations(
 
         # Gather the top K scores from _all_ beams.
         if do_sample:
-            # Handle potential NaN values in accumulated_log_probs
-            probs = nn.functional.softmax(accumulated_log_probs, dim=-1)
-            # Replace NaN values with uniform distribution
-            if torch.isnan(probs).any():
-                # Create a mask for NaN positions
-                nan_mask = torch.isnan(probs)
-                # Replace NaN with a small uniform probability
-                probs = torch.where(nan_mask, torch.ones_like(probs) / probs.shape[-1], probs)
-                # Renormalize to ensure probabilities sum to 1
-                probs = probs / probs.sum(dim=-1, keepdim=True)
-            topk_indices = torch.multinomial(probs, num_samples=beams_to_keep)
+            topk_indices = torch.multinomial(
+                nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
+            )
             topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
         else:
             topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)
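For context, here is a minimal, self-contained sketch of what the two branches do after this change. The toy values and the `(batch_size, num_beams * vocab_size)` layout of `accumulated_log_probs` are assumptions for illustration only; they mirror the names in the diff but not the exact shapes or call sites of the real method.

```python
import torch
from torch import nn

# Toy stand-ins for what the real method receives (shapes are assumptions).
batch_size, num_beams, vocab_size = 2, 3, 5
beams_to_keep = 4
accumulated_log_probs = torch.log_softmax(
    torch.randn(batch_size, num_beams * vocab_size), dim=-1
)

do_sample = True
if do_sample:
    # Sampling branch after the change: turn the accumulated log-probs into a
    # probability distribution and draw `beams_to_keep` candidate continuations.
    topk_indices = torch.multinomial(
        nn.functional.softmax(accumulated_log_probs, dim=-1), num_samples=beams_to_keep
    )
    # Recover the log-probability of each sampled candidate.
    topk_log_probs = torch.gather(input=accumulated_log_probs, dim=1, index=topk_indices)
else:
    # Greedy branch (unchanged): keep the highest-scoring candidates instead.
    topk_log_probs, topk_indices = torch.topk(accumulated_log_probs, k=beams_to_keep)

print(topk_indices.shape, topk_log_probs.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```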