@@ -603,35 +603,42 @@ def prepare_inputs_for_generation(
         **kwargs,
     ):
         # Omit tokens covered by past_key_values
+        past_length = 0
+        token_num = (
+            input_ids.shape[1] + self.config.input_token_len - 1
+        ) // self.config.input_token_len
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 past_length = past_key_values.get_seq_length()
             else:
                 past_length = past_key_values[0][0].shape[2]
 
+        if past_key_values is not None and past_length > 0:
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
-            if attention_mask is not None and attention_mask.shape[1] > (
-                input_ids.shape[1] // self.config.input_token_len
-            ):
+            if attention_mask is not None and attention_mask.shape[1] > token_num:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif past_length < (input_ids.shape[1] // self.config.input_token_len):
-                input_ids = input_ids[:, past_length * self.config.input_token_len :]
+            elif past_length < token_num:
+                # TODO: Actually, we need to know the output_token_lens used in the last generation step.
+                # Sundial will pad the input when it is non-divisible, so we cannot use past_length to slice input_ids
+                input_ids = input_ids[:, -self.config.output_token_lens[0] :]
             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
+            if past_key_values is not None and past_length > 0:
+                token_num = (
+                    input_ids.shape[1] + self.config.input_token_len - 1
+                ) // self.config.input_token_len
+                position_ids = position_ids[:, -token_num :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
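
For context on the new `token_num` computation: the previous comparisons used floor division, `input_ids.shape[1] // self.config.input_token_len`, which undercounts by one patch whenever the input length is not a multiple of `input_token_len` and the model pads the remainder. The patch switches to ceiling division. A minimal sketch of the difference, using a made-up `input_token_len` of 16 and a 50-step input (both values are illustrative, not taken from the config):

```python
# Illustrative only -- hypothetical values, not part of the patch.
input_token_len = 16
seq_len = 50  # pretend input_ids.shape[1] == 50

floor_tokens = seq_len // input_token_len                         # 3: misses the padded partial patch
ceil_tokens = (seq_len + input_token_len - 1) // input_token_len  # 4: counts the padded partial patch

print(floor_tokens, ceil_tokens)  # 3 4
```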
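Likewise, the position_ids branch now trims to the same ceiling-divided token count, and only when the cache actually holds tokens (`past_length > 0`). A standalone sketch of how the mask-derived positions are built and trimmed, with a toy left-padded `attention_mask` and made-up lengths (all values are illustrative):

```python
import torch

# Illustrative only -- toy values, not taken from the model or its config.
attention_mask = torch.tensor([[0, 1, 1, 1, 1]])  # one left-padded sequence
input_token_len = 2
input_len = 3  # pretend input_ids.shape[1] == 3

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # padded slots get a dummy position

token_num = (input_len + input_token_len - 1) // input_token_len  # ceil -> 2
position_ids = position_ids[:, -token_num:]  # keep positions for the unprocessed patches

print(position_ids)  # tensor([[2, 3]])
```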