@@ -249,7 +249,7 @@ defmodule Bumblebee.Text.Generation do
249249 } )
250250
251251 max_length = max_length_fun . ( 1 )
252- inputs = prepare_decoder_inputs ( inputs , "decoder_" , spec , max_length )
252+ inputs = prepare_decoder_inputs ( inputs , "decoder_" , spec , model , max_length )
253253 { inputs , inputs [ "decoder_input_ids" ] , max_length }
254254 end
255255
@@ -260,7 +260,7 @@ defmodule Bumblebee.Text.Generation do
260260 prepare_inputs_fun = fn inputs , _params ->
261261 sequence_length = Nx . axis_size ( inputs [ "input_ids" ] , 1 )
262262 max_length = max_length_fun . ( sequence_length )
263- inputs = prepare_decoder_inputs ( inputs , "" , spec , max_length )
263+ inputs = prepare_decoder_inputs ( inputs , "" , spec , model , max_length )
264264 { inputs , inputs [ "input_ids" ] , max_length }
265265 end
266266
@@ -279,7 +279,7 @@ defmodule Bumblebee.Text.Generation do
279279 inputs [ "input_ids" ] || inputs [ "input_features" ] || inputs [ "pixel_values" ]
280280 end
281281
282- defp prepare_decoder_inputs ( inputs , prefix , spec , max_length ) do
282+ defp prepare_decoder_inputs ( inputs , prefix , spec , model , max_length ) do
283283 input_ids = inputs [ prefix <> "input_ids" ]
284284 attention_mask = inputs [ prefix <> "attention_mask" ] || Nx . broadcast ( 1 , input_ids )
285285
@@ -295,9 +295,32 @@ defmodule Bumblebee.Text.Generation do
295295
296296 batch_size = Nx . axis_size ( input_ids , 0 )
297297 cache = init_cache ( spec , batch_size , max_length , inputs )
298+
299+ output_policy = model_output_policy ( model )
300+
301+ # TODO: fix Axon.MixedPrecision.cast/3 to not cast integers, to
302+ # match Axon compiler
303+
304+ # Cast all float cache tensors to match the model output. This way
305+ # we make sure the cache we pass as input has the same types as
306+ # the updated cache returned from the model
307+ cache =
308+ Bumblebee.Utils.Nx . map ( cache , fn tensor ->
309+ if Nx.Type . integer? ( Nx . type ( tensor ) ) do
310+ tensor
311+ else
312+ Axon.MixedPrecision . cast ( output_policy , tensor , :output )
313+ end
314+ end )
315+
298316 Map . put ( inputs , "cache" , cache )
299317 end
300318
# Returns the policy stored on the node that `Axon.pop_node/1` pops
# off the given model; the rest of the popped result is ignored.
defp model_output_policy(model) do
  {output_node, _remaining_model} = Axon.pop_node(model)
  output_node.policy
end
323+
301324 defp update_decoder_inputs ( prefix , inputs , cache , token_ids ) do
302325 inputs
303326 |> Map . replace! ( prefix <> "input_ids" , token_ids )
0 commit comments