@@ -7,7 +7,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
77 opts = Keyword . validate! ( opts , [ :suppressed_token_ids ] )
88
99 indices = opts [ :suppressed_token_ids ] |> Nx . tensor ( ) |> Nx . new_axis ( - 1 )
10- values = Nx . broadcast ( Nx.Constants . neg_infinity ( ) , { Nx . size ( indices ) } )
10+ values = Nx . broadcast ( Nx.Constants . neg_infinity ( Nx . type ( logits ) ) , { Nx . size ( indices ) } )
1111 Nx . indexed_put ( logits , indices , values )
1212 end
1313
@@ -97,7 +97,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
9797 indices = Nx . new_axis ( token_id , - 1 )
9898
9999 match? = Nx . all ( ngram_but_one == last_ngram_but_one )
100- updates = Nx . select ( match? , Nx.Constants . neg_infinity ( ) , 0 )
100+ updates = Nx . select ( match? , Nx.Constants . neg_infinity ( Nx . type ( logits ) ) , 0 )
101101 logits = Nx . indexed_add ( logits , indices , updates )
102102
103103 { i + 1 , last_ngram_but_one , sequence , length , logits }
@@ -108,13 +108,17 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
108108 end
109109
# Forces generation of exactly `token_id` at this step: every logit is set
# to negative infinity except the target position, which is set to 0, so
# argmax/sampling can only pick `token_id`.
#
# NOTE(review): the rest of this changeset passes `Nx.type(logits)` into
# `Nx.Constants.neg_infinity/1` at every other call site; do the same here
# for consistency. The explicit `type:` option on `Nx.fill/3` is kept so the
# filled tensor's dtype is guaranteed to match `logits` regardless of the
# constant's default type.
deftransformp force_token_id(logits, token_id) do
  logits
  |> Nx.fill(Nx.Constants.neg_infinity(Nx.type(logits)), type: Nx.type(logits))
  |> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
end
115115
# Masks out a single token: overwrites the logit at `token_id` with negative
# infinity (typed to match `logits`) so it can never be selected.
deftransformp ignore_token_id(logits, token_id) do
  neg_inf = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {1})
  Nx.put_slice(logits, [token_id], neg_inf)
end
119123
120124 defn temperature_processor ( logits , _context , opts \\ [ ] ) do
@@ -132,7 +136,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
132136
133137 { top_k_logits , _ } = Nx . top_k ( logits , k: top_k )
134138 kth_logit = top_k_logits [ - 1 ]
135- Nx . select ( logits < kth_logit , Nx.Constants . neg_infinity ( ) , logits )
139+ Nx . select ( logits < kth_logit , Nx.Constants . neg_infinity ( Nx . type ( logits ) ) , logits )
136140 end
137141
138142 defn top_p_processor ( logits , _context , opts \\ [ ] ) do
@@ -152,12 +156,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
152156 # Arrange the mask back into the original logits order
153157 ignore_mask =
154158 Nx . indexed_put (
155- Nx . broadcast ( 0.0 , Nx . shape ( sorted_idx ) ) ,
159+ Nx . fill ( ordered_ignore_mask , 0 ) ,
156160 Nx . new_axis ( sorted_idx , - 1 ) ,
157161 Nx . flatten ( ordered_ignore_mask )
158162 )
159163
160- Nx . select ( ignore_mask , Nx.Constants . neg_infinity ( ) , logits )
164+ Nx . select ( ignore_mask , Nx.Constants . neg_infinity ( Nx . type ( logits ) ) , logits )
161165 end
162166
163167 defn whisper_timestamp_processor ( logits , context , opts \\ [ ] ) do
@@ -224,7 +228,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
224228 )
225229 )
226230
227- Nx . select ( ignore_mask , Nx.Constants . neg_infinity ( ) , logits )
231+ Nx . select ( ignore_mask , Nx.Constants . neg_infinity ( Nx . type ( logits ) ) , logits )
228232 end
229233
230234 defnp maybe_force_timestamp ( logits , timestamp_begin_id ) do
@@ -242,7 +246,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
242246 force_timestamp_mask = timestamp_log_probability > max_token_log_probability
243247 tokens_mask = Nx . iota ( Nx . shape ( logits ) ) < timestamp_begin_id
244248 ignore_mask = force_timestamp_mask and tokens_mask
245- Nx . select ( ignore_mask , Nx.Constants . neg_infinity ( ) , logits )
249+ Nx . select ( ignore_mask , Nx.Constants . neg_infinity ( Nx . type ( logits ) ) , logits )
246250 end
247251
248252 deftransformp begin_idx ( forced_token_ids ) do
0 commit comments