@@ -603,38 +603,36 @@ def prepare_inputs_for_generation(
             **kwargs,
     ):
         # Omit tokens covered by past_key_values
-        past_length = 0
-        token_num = (
-            input_ids.shape[1] + self.config.input_token_len - 1
-        ) // self.config.input_token_len
-
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 past_length = past_key_values.get_seq_length()
             else:
                 past_length = past_key_values[0][0].shape[2]
 
-        if past_key_values is not None and past_length > 0:
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
-            if attention_mask is not None and attention_mask.shape[1] > token_num:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+            if attention_mask is not None and attention_mask.shape[1] > (
+                input_ids.shape[1] // self.config.input_token_len
+            ):
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif past_length < token_num:
-                # TODO: Actually, we need to know the output_token_lens used in the last generation step.
-                # Sundial will pad the input when it is non-divisible, so we cannot use past_length to slice input_ids
-                input_ids = input_ids[:, -self.config.output_token_lens[0]:]
+            elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+                input_ids = input_ids[:, past_length * self.config.input_token_len :]
             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values is not None and past_length > 0:
+            if past_key_values is not None:
                 token_num = (
                     input_ids.shape[1] + self.config.input_token_len - 1
                 ) // self.config.input_token_len
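
For context, here is a minimal, self-contained sketch of the slicing arithmetic the new branches implement. It is not the repository's code: the helper name `keep_unprocessed_points` is hypothetical, and it assumes `input_ids` holds raw time points, `past_length` counts cached patch tokens, and each patch token covers `input_token_len` points.

```python
from typing import Optional

import torch


def keep_unprocessed_points(
    input_ids: torch.Tensor,                     # (batch, num_points)
    past_length: int,                            # cached length, in patch tokens
    input_token_len: int,                        # time points per patch token
    attention_mask: Optional[torch.Tensor] = None,  # (batch, mask_len) in tokens
) -> torch.Tensor:
    """Sketch of the input slicing for patch-based generation with a KV cache."""
    if attention_mask is not None and attention_mask.shape[1] > (
        input_ids.shape[1] // input_token_len
    ):
        # Case 1: part of the history lives only in the cache; keep the tail the
        # mask says is still unprocessed, converted from tokens back to points.
        return input_ids[:, -(attention_mask.shape[1] - past_length) * input_token_len :]
    if past_length < input_ids.shape[1] // input_token_len:
        # Case 2: input_ids holds the full history; drop the points already cached.
        return input_ids[:, past_length * input_token_len :]
    # Case 3: assume input_ids already contains only unprocessed points.
    return input_ids


# Example: 2 patches of 48 points are cached out of a 144-point history,
# so only the last 48 points should be fed to the model.
x = torch.randn(1, 144)
print(keep_unprocessed_points(x, past_length=2, input_token_len=48).shape)  # torch.Size([1, 48])
```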