@@ -243,6 +243,7 @@ defmodule Bumblebee.Text.Generation do
       prepare_inputs_fun = fn inputs, params ->
         encoder_outputs = encoder_predict_fun.(params, inputs)
 
+        padded_batch_item? = padded_batch_item?(encoder_input(inputs))
         batch_size = Nx.axis_size(encoder_input(inputs), 0)
 
         inputs = Map.put(inputs, "encoder_hidden_state", encoder_outputs.hidden_state)
@@ -254,18 +255,19 @@ defmodule Bumblebee.Text.Generation do
 
         max_length = max_length_fun.(1)
         inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
-        {inputs, inputs["decoder_input_ids"], max_length}
+        {inputs, inputs["decoder_input_ids"], padded_batch_item?, max_length}
       end
 
       update_inputs_fun = &update_decoder_inputs("decoder_", &1, &2, &3)
 
       {prepare_inputs_fun, update_inputs_fun}
     else
       prepare_inputs_fun = fn inputs, _params ->
+        padded_batch_item? = padded_batch_item?(inputs["input_ids"])
         sequence_length = Nx.axis_size(inputs["input_ids"], 1)
         max_length = max_length_fun.(sequence_length)
         inputs = prepare_decoder_inputs(inputs, "", spec, model, max_length)
-        {inputs, inputs["input_ids"], max_length}
+        {inputs, inputs["input_ids"], padded_batch_item?, max_length}
       end
 
       update_inputs_fun = &update_decoder_inputs("", &1, &2, &3)
@@ -283,6 +285,13 @@ defmodule Bumblebee.Text.Generation do
     inputs["input_ids"] || inputs["input_features"] || inputs["pixel_values"]
   end
 
+  defp padded_batch_item?(input) do
+    [_ | non_batch_axes] = Nx.axes(input)
+    # We check whether each batch item is full of zeros, in which
+    # case we assume it's padding, not an actual input.
+    input |> Nx.equal(0) |> Nx.all(axes: non_batch_axes)
+  end
+
   defp prepare_decoder_inputs(inputs, prefix, spec, model, max_length) do
     input_ids = inputs[prefix <> "input_ids"]
     attention_mask = inputs[prefix <> "attention_mask"] || Nx.broadcast(1, input_ids)
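For intuition, here is a minimal sketch (not part of the diff) of what the new `padded_batch_item?/1` helper computes. The tensor values are made up for illustration; the assumption, as the comment above states, is that padded batch items arrive as all-zero inputs:

```elixir
# Hypothetical batch: the first item is a real input, the second is an
# all-zero entry used to pad the batch up to a fixed batch size.
input_ids = Nx.tensor([[101, 2023, 102], [0, 0, 0]])

[_ | non_batch_axes] = Nx.axes(input_ids)

input_ids |> Nx.equal(0) |> Nx.all(axes: non_batch_axes)
#=> u8 tensor of shape {2}: [0, 1]
```

Each batch item maps to `1` exactly when all of its elements are zero, giving a per-item mask that the rest of the change threads through generation.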
@@ -396,7 +405,8 @@ defmodule Bumblebee.Text.Generation do
        ) do
     {seed, inputs} = pop_seed(inputs)
 
-    {decoder_inputs, decoder_input_ids, max_length} = prepare_inputs_fun.(inputs, params)
+    {decoder_inputs, decoder_input_ids, padded_batch_item?, max_length} =
+      prepare_inputs_fun.(inputs, params)
 
     length = Nx.axis_size(decoder_input_ids, 1)
 
@@ -414,6 +424,7 @@ defmodule Bumblebee.Text.Generation do
         greedy(
           decoder_inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -425,6 +436,7 @@ defmodule Bumblebee.Text.Generation do
         contrastive(
           decoder_inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -440,6 +452,7 @@ defmodule Bumblebee.Text.Generation do
         sampling(
           decoder_inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           seed,
@@ -469,6 +482,7 @@ defmodule Bumblebee.Text.Generation do
   defnp greedy(
           inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -479,7 +493,7 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]
 
-    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
 
     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
@@ -519,15 +533,17 @@ defmodule Bumblebee.Text.Generation do
     state
   end
 
-  defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
+  defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) do
     {batch_size, length} = Nx.shape(decoder_input_ids)
 
     sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
     sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)
 
     # For each sequence, we keep track of its final length, where 0
-    # means that it has not been finished yet
-    finished_length = Nx.broadcast(0, {batch_size})
+    # means that it has not been finished yet. If there are padding
+    # batch inputs, we immediately mark them as finished, otherwise
+    # they could produce arbitrary tokens until we reach max length.
+    finished_length = Nx.select(padded_batch_item?, 1, 0)
 
     %{
       sequences: sequences,
@@ -631,6 +647,7 @@ defmodule Bumblebee.Text.Generation do
   defnp contrastive(
           inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           logits_processor_fun,
@@ -644,7 +661,7 @@ defmodule Bumblebee.Text.Generation do
     top_k = opts[:top_k]
     penalty_alpha = opts[:penalty_alpha]
 
-    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
 
     # Step (1)
     # Initial pass to obtain hidden state and expand inputs to top-k
@@ -796,6 +813,7 @@ defmodule Bumblebee.Text.Generation do
   defnp sampling(
           inputs,
           decoder_input_ids,
+          padded_batch_item?,
           predict_fun,
           params,
           seed,
@@ -807,7 +825,7 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]
 
-    state = init_sequences(decoder_input_ids, max_length, pad_token_id)
+    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
 
     prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()
 