@@ -57,16 +57,17 @@ def sample_manual_loop_no_classes(
5757 if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step :
5858 eos_score = cfg_logits [:, eos_token_id ].clone ()
5959
60+ remove_logit_value = torch .finfo (cfg_logits .dtype ).min
6061 # Only generate audio tokens
61- cfg_logits [:, :audio_start_id ] = float ( '-inf' )
62+ cfg_logits [:, :audio_start_id ] = remove_logit_value
6263
6364 if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step :
6465 cfg_logits [:, eos_token_id ] = eos_score
6566
6667 if top_k is not None and top_k > 0 :
6768 top_k_vals , _ = torch .topk (cfg_logits , top_k )
6869 min_val = top_k_vals [..., - 1 , None ]
69- cfg_logits [cfg_logits < min_val ] = float ( '-inf' )
70+ cfg_logits [cfg_logits < min_val ] = remove_logit_value
7071
7172 if top_p is not None and top_p < 1.0 :
7273 sorted_logits , sorted_indices = torch .sort (cfg_logits , descending = True )
@@ -75,7 +76,7 @@ def sample_manual_loop_no_classes(
7576 sorted_indices_to_remove [..., 1 :] = sorted_indices_to_remove [..., :- 1 ].clone ()
7677 sorted_indices_to_remove [..., 0 ] = 0
7778 indices_to_remove = sorted_indices_to_remove .scatter (1 , sorted_indices , sorted_indices_to_remove )
78- cfg_logits [indices_to_remove ] = float ( '-inf' )
79+ cfg_logits [indices_to_remove ] = remove_logit_value
7980
8081 if temperature > 0 :
8182 cfg_logits = cfg_logits / temperature
0 commit comments