Skip to content

Commit 05401ed

Browse files
committed
Preserve input type in logits processing
1 parent 3bca4c5 commit 05401ed

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

lib/bumblebee/text/generation/logits_processing.ex

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
77
opts = Keyword.validate!(opts, [:suppressed_token_ids])
88

99
indices = opts[:suppressed_token_ids] |> Nx.tensor() |> Nx.new_axis(-1)
10-
values = Nx.broadcast(Nx.Constants.neg_infinity(), {Nx.size(indices)})
10+
values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(indices)})
1111
Nx.indexed_put(logits, indices, values)
1212
end
1313

@@ -97,7 +97,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
9797
indices = Nx.new_axis(token_id, -1)
9898

9999
match? = Nx.all(ngram_but_one == last_ngram_but_one)
100-
updates = Nx.select(match?, Nx.Constants.neg_infinity(), 0)
100+
updates = Nx.select(match?, Nx.Constants.neg_infinity(Nx.type(logits)), 0)
101101
logits = Nx.indexed_add(logits, indices, updates)
102102

103103
{i + 1, last_ngram_but_one, sequence, length, logits}
@@ -108,13 +108,17 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
108108
end
109109

110110
deftransformp force_token_id(logits, token_id) do
111-
Nx.Constants.neg_infinity()
112-
|> Nx.broadcast(logits)
113-
|> Nx.put_slice([token_id], Nx.tensor([0]))
111+
logits
112+
|> Nx.fill(Nx.Constants.neg_infinity(), type: Nx.type(logits))
113+
|> Nx.put_slice([token_id], Nx.tensor([0], type: Nx.type(logits)))
114114
end
115115

116116
deftransformp ignore_token_id(logits, token_id) do
117-
Nx.put_slice(logits, [token_id], Nx.broadcast(Nx.Constants.neg_infinity(), {1}))
117+
Nx.put_slice(
118+
logits,
119+
[token_id],
120+
Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {1})
121+
)
118122
end
119123

120124
defn temperature_processor(logits, _context, opts \\ []) do
@@ -132,7 +136,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
132136

133137
{top_k_logits, _} = Nx.top_k(logits, k: top_k)
134138
kth_logit = top_k_logits[-1]
135-
Nx.select(logits < kth_logit, Nx.Constants.neg_infinity(), logits)
139+
Nx.select(logits < kth_logit, Nx.Constants.neg_infinity(Nx.type(logits)), logits)
136140
end
137141

138142
defn top_p_processor(logits, _context, opts \\ []) do
@@ -152,12 +156,12 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
152156
# Arrange the mask back into the original logits order
153157
ignore_mask =
154158
Nx.indexed_put(
155-
Nx.broadcast(0.0, Nx.shape(sorted_idx)),
159+
Nx.fill(ordered_ignore_mask, 0),
156160
Nx.new_axis(sorted_idx, -1),
157161
Nx.flatten(ordered_ignore_mask)
158162
)
159163

160-
Nx.select(ignore_mask, Nx.Constants.neg_infinity(), logits)
164+
Nx.select(ignore_mask, Nx.Constants.neg_infinity(Nx.type(logits)), logits)
161165
end
162166

163167
defn whisper_timestamp_processor(logits, context, opts \\ []) do
@@ -224,7 +228,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
224228
)
225229
)
226230

227-
Nx.select(ignore_mask, Nx.Constants.neg_infinity(), logits)
231+
Nx.select(ignore_mask, Nx.Constants.neg_infinity(Nx.type(logits)), logits)
228232
end
229233

230234
defnp maybe_force_timestamp(logits, timestamp_begin_id) do
@@ -242,7 +246,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
242246
force_timestamp_mask = timestamp_log_probability > max_token_log_probability
243247
tokens_mask = Nx.iota(Nx.shape(logits)) < timestamp_begin_id
244248
ignore_mask = force_timestamp_mask and tokens_mask
245-
Nx.select(ignore_mask, Nx.Constants.neg_infinity(), logits)
249+
Nx.select(ignore_mask, Nx.Constants.neg_infinity(Nx.type(logits)), logits)
246250
end
247251

248252
deftransformp begin_idx(forced_token_ids) do

test/bumblebee/text/generation/logits_processing_test.exs

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,15 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
153153
logits
154154
)
155155
end
156+
157+
test "keeps input type" do
158+
logits = Nx.tensor([1.0, 2.0, 3.0, 4.0], type: :f16)
159+
160+
context = context([2, 3, 2, 0])
161+
162+
result = LogitsProcessing.no_repeat_ngram_processor(logits, context, ngram_length: 2)
163+
assert Nx.type(result) == {:f, 16}
164+
end
156165
end
157166

158167
describe "temperature_processor/3" do

0 commit comments

Comments
 (0)