Skip to content

Commit 7b0d06c

Browse files
Make sure the initial decoding cache has the proper types (#346)
1 parent: 601551a · commit: 7b0d06c

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

lib/bumblebee/text/generation.ex

Lines changed: 26 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -249,7 +249,7 @@ defmodule Bumblebee.Text.Generation do
249249
})
250250

251251
max_length = max_length_fun.(1)
252-
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, max_length)
252+
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
253253
{inputs, inputs["decoder_input_ids"], max_length}
254254
end
255255

@@ -260,7 +260,7 @@ defmodule Bumblebee.Text.Generation do
260260
prepare_inputs_fun = fn inputs, _params ->
261261
sequence_length = Nx.axis_size(inputs["input_ids"], 1)
262262
max_length = max_length_fun.(sequence_length)
263-
inputs = prepare_decoder_inputs(inputs, "", spec, max_length)
263+
inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
264264
{inputs, inputs["input_ids"], max_length}
265265
end
266266

@@ -279,7 +279,7 @@ defmodule Bumblebee.Text.Generation do
279279
inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
280280
end
281281

282-
defp prepare_decoder_inputs(inputs, prefix, spec, max_length) do
282+
defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
283283
input_ids = inputs[prefix <> "input_ids"]
284284
attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)
285285

@@ -295,9 +295,32 @@ defmodule Bumblebee.Text.Generation do
295295

296296
batch_size = Nx.axis_size(input_ids, 0)
297297
cache = init_cache(spec, batch_size, max_length, inputs)
298+
299+
output_policy = model_output_policy(model)
300+
301+
# TODO: fix Axon.MixedPrecision.cast/2 to not cast integers, to
302+
# match Axon compiler
303+
304+
# Cast all float cache tensors to match the model output. This way
305+
# we make sure the cache we pass as input has the same types as
306+
# the updated cache returned from the model
307+
cache =
308+
Bumblebee.Utils.Nx.map(cache, fn tensor ->
309+
if Nx.Type.integer?(Nx.type(tensor)) do
310+
tensor
311+
else
312+
Axon.MixedPrecision.cast(output_policy, tensor, :output)
313+
end
314+
end)
315+
298316
Map.put(inputs, "cache", cache)
299317
end
300318

319+
defp model_output_policy(model) do
320+
{node, _} = Axon.pop_node(model)
321+
node.policy
322+
end
323+
301324
defp update_decoder_inputs(prefix, inputs, cache, token_ids) do
302325
inputs
303326
|> Map.replace!(prefix <> "input_ids", token_ids)

test/bumblebee/text/bart_test.exs

Lines changed: 28 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -154,4 +154,32 @@ defmodule Bumblebee.Text.BartTest do
154154

155155
assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
156156
end
157+
158+
test "generation with :for_conditional_generation and lower precision" do
159+
assert {:ok, %{model: model, params: params, spec: spec}} =
160+
Bumblebee.load_model(
161+
{:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"},
162+
type: :f16
163+
)
164+
165+
{:ok, generation_config} =
166+
Bumblebee.load_generation_config(
167+
{:hf, "hf-internal-testing/tiny-random-BartForConditionalGeneration"}
168+
)
169+
170+
assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = spec
171+
172+
inputs = %{
173+
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
174+
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
175+
"seed" => Nx.tensor([0])
176+
}
177+
178+
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
179+
180+
generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config)
181+
%{token_ids: token_ids} = generate.(params, inputs)
182+
183+
assert_equal(token_ids, Nx.tensor([[988, 988, 988]]))
184+
end
157185
end

0 commit comments

Comments (0)