diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py
index cb43f8db6fd2..ddc091336611 100644
--- a/paddlenlp/transformers/generation_utils.py
+++ b/paddlenlp/transformers/generation_utils.py
@@ -30,7 +30,7 @@
 from paddlenlp.utils.log import logger
 
-from .model_outputs import ModelOutput
+from .model_outputs import CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput
 
 __all__ = ["GenerationMixin"]
 
@@ -540,6 +540,7 @@ def generate(
         temperature=1.0,
         top_k=0,
         top_p=1.0,
+        penalty_alpha=0.6,
         repetition_penalty=1.0,
         num_beams=1,
         num_beam_groups=1,
@@ -730,9 +731,10 @@ def generate(
 
         assert decode_strategy in [
             "greedy_search",
+            "contrastive_search",
             "sampling",
             "beam_search",
-        ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
+        ], "`decode_strategy` must be one of 'greedy_search', 'contrastive_search', 'sampling' or 'beam_search' but received {}.".format(
             decode_strategy
         )
 
@@ -881,6 +883,24 @@ def generate(
             return self.greedy_search(
                 input_ids, logits_processors, max_len, pad_token_id, eos_token_id, **model_kwargs
             )
+        elif decode_strategy == "contrastive_search":
+            if num_return_sequences > 1:
+                raise ValueError(
+                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing"
+                    " contrastive search."
+                )
+            return self.contrastive_search(
+                input_ids,
+                logits_processors,
+                max_len,
+                pad_token_id,
+                eos_token_id,
+                penalty_alpha,
+                top_k,
+                top_p,
+                temperature,
+                **model_kwargs,
+            )
         elif decode_strategy == "sampling":
             if num_return_sequences > 1:
 
@@ -1030,6 +1050,183 @@ def greedy_search(self, input_ids, logits_processors, max_length, pad_token_id,
 
         return input_ids[:, origin_len:], scores
 
+    def contrastive_search(
+        self,
+        input_ids,
+        logits_processors,
+        max_length,
+        pad_token_id,
+        eos_token_id,
+        penalty_alpha,
+        top_k=None,
+        top_p=None,
+        temperature=None,
+        min_tokens_to_keep=1,
+        **model_kwargs
+    ):
+
+        logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList()
+        batch_size, cur_len = input_ids.shape
+        origin_len = cur_len
+        unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
+        scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())
+
+        while cur_len < max_length:
+            # on the first step of the loop, encode the whole prefix and obtain: (1) past_key_values;
+            # (2) last_hidden_states; (3) logit_for_next_step; (4) updated model kwargs for the next step
+            if model_kwargs.get("cache") is None:
+                # prepare inputs
+                model_kwargs["use_cache"] = True
+                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+                # encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and save
+                # the `encoder_outputs`
+                outputs = self(**model_inputs, return_dict=True, output_hidden_states=True)
+
+                # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
+                # previous tokens)
+                if self.is_encoder_decoder:
+                    last_hidden_states = outputs.decoder_hidden_states[-1]
+                else:
+                    last_hidden_states = outputs.hidden_states[-1]
+
+                # next logit for contrastive search to select top-k candidate tokens
+                logit_for_next_step = outputs.logits[:, -1, :]
+
+                model_kwargs = self.update_model_kwargs_for_generation(
+                    outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+                )  # the cache is obtained internally via extract_past_from_model_output
+
+                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+                _, model_kwargs = self.expand_inputs_for_generation(
+                    input_ids, expand_size=top_k, **model_kwargs
+                )
+
+            past_key_values = model_kwargs.get("cache")  # TODO: confirm that this is the right cache to use here
+            if past_key_values is None:
+                raise ValueError(
+                    f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
+                    "for contrastive search."
+                )
+            elif (
+                not isinstance(past_key_values[0], (tuple, paddle.Tensor))
+                or past_key_values[0][0].shape[0] != batch_size
+            ):
+                raise ValueError(
+                    f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
+                    "used for contrastive search without further modifications."
+                )
+
+            # contrastive_search main logic start:
+            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
+            # degeneration penalty
+
+            # pre-process distribution
+            logit_for_next_step = self.adjust_logits_during_generation(logit_for_next_step)
+            logit_for_next_step = logits_processors(input_ids, logit_for_next_step)
+
+            origin_probs = F.softmax(logit_for_next_step)
+            origin_probs = paddle.log(origin_probs)  # TODO: decide which score should be reported
+
+            if temperature is not None and temperature != 1.0:
+                logit_for_next_step = logit_for_next_step / temperature
+
+            next_probs = F.softmax(logit_for_next_step)
+            if top_k is not None and top_k != 0:
+                next_probs = TopKProcess(next_probs, top_k, min_tokens_to_keep)
+            if top_p is not None and top_p < 1.0:
+                next_probs = TopPProcess(next_probs, top_p, min_tokens_to_keep)
+            top_k_probs, top_k_ids = paddle.topk(next_probs, k=top_k)
+
+            # Replicates the new past_key_values to match the `top_k` candidates
+            new_key_values = []
+            for layer in model_kwargs["cache"]:
+                items = []
+                # item is either the key or the value matrix
+                for item in layer:
+                    items.append(item.repeat_interleave(top_k, axis=0))
+                new_key_values.append(items)
+            model_kwargs["cache"] = new_key_values
+
+            # compute the candidate tokens by the language model and collect their hidden_states
+            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.reshape([-1, 1]), **model_kwargs)
+            outputs = self(**next_model_inputs, return_dict=True, output_hidden_states=True)
+            next_past_key_values = outputs.past_key_values
+
+            logits = outputs.logits[:, -1, :]
+            # name is different for encoder-decoder and decoder-only models
+            if self.is_encoder_decoder:
+                next_hidden = outputs.decoder_hidden_states[-1]
+                full_hidden_states = outputs.decoder_hidden_states
+            else:
+                next_hidden = outputs.hidden_states[-1]
+                full_hidden_states = outputs.hidden_states
+            context_hidden = last_hidden_states.repeat_interleave(top_k, axis=0)
+
+            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
+            # model confidence
+            selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
+
+            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
+            # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
+            # (model confidence minus degeneration penalty); (6) decoder hidden_states
+            next_tokens = paddle.index_sample(top_k_ids, selected_idx.unsqueeze(-1))
+
+            next_scores = paddle.index_sample(origin_probs, next_tokens)  # TODO: decide which score should be reported
+
+            next_hidden = paddle.index_select(next_hidden, selected_idx)
+            last_hidden_states = paddle.concat([last_hidden_states, next_hidden], axis=1)
+
+            next_decoder_hidden_states = ()
+            for layer in full_hidden_states:
+                layer = paddle.index_select(layer, selected_idx)
+                next_decoder_hidden_states += (layer,)
+
+            # select the past_key_value
+            new_key_values = ()
+            for layer in next_past_key_values:
+                items = ()
+                # item is either the key or the value matrix
+                for item in layer:
+                    item = paddle.index_select(item, selected_idx)  # [B, K, num_head, seq_len, esz]
+                    # item = item[:, selected_idx, ...]  # [B, num_head, seq_len, esz]
+                    items += (item,)
+                new_key_values += (items,)
+            next_past_key_values = new_key_values
+
+            logit_for_next_step = paddle.index_select(logits, selected_idx)
+
+            # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
+            if self.is_encoder_decoder:
+                outputs = Seq2SeqLMOutput(
+                    past_key_values=next_past_key_values,
+                    decoder_hidden_states=next_decoder_hidden_states,
+                )
+            else:
+                outputs = CausalLMOutputWithPast(
+                    past_key_values=next_past_key_values,
+                    hidden_states=next_decoder_hidden_states,
+                )
+
+            if eos_token_id is not None:
+                next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id))
+
+            scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag)
+
+            cur_len += 1
+            input_ids = paddle.concat([input_ids, next_tokens], axis=1)
+
+            if eos_token_id is not None:
+                unfinished_flag = paddle.logical_and(unfinished_flag, next_tokens != eos_token_id)
+
+            # Stop when there is an eos token in all sentences
+            if not paddle.any(unfinished_flag):
+                break
+            model_kwargs = self.update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
+            )
+        return input_ids[:, origin_len:], scores
+
     def sample(
         self,
         input_ids,
@@ -1738,3 +1935,26 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
     condition = paddle.cast(condition, "bool").reshape(probs.shape)
     probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
     return probs
+
+
+def _ranking_fast(
+    context_hidden,
+    next_hidden,
+    next_top_k_probs,
+    alpha: float,
+    beam_width: int,
+):
+    """
+    Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
+    in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
+    row in the batch.
+    """
+    norm_context_hidden = context_hidden / context_hidden.norm(axis=2, keepdim=True)
+    norm_next_hidden = next_hidden / next_hidden.norm(axis=2, keepdim=True)
+    cosine_matrix = paddle.matmul(norm_context_hidden, norm_next_hidden.transpose([0, 2, 1])).squeeze(-1)
+    degeneration_penalty = paddle.max(cosine_matrix, axis=-1)
+    next_top_k_probs = next_top_k_probs.reshape([-1])
+    contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
+    contrastive_score = contrastive_score.reshape([-1, beam_width])  # [batch_size, beam_width]
+    selected_idx = paddle.argmax(contrastive_score, axis=-1)
+    return selected_idx
diff --git a/tests/transformers/test_generation_utils.py b/tests/transformers/test_generation_utils.py
index 18fc6d173383..3b0ebb0ad3e2 100644
--- a/tests/transformers/test_generation_utils.py
+++ b/tests/transformers/test_generation_utils.py
@@ -250,6 +250,59 @@ def _greedy_generate(
         )
         return output_greedy, output_generate
 
+    def _contrastive_generate(
+        self,
+        model,
+        input_ids,
+        attention_mask,
+        max_length,
+        num_return_sequences,
+        logits_processors,
+        logits_warper,
+        process_kwargs,
+    ):
+        with paddle.no_grad():
+            output_generate = model.generate(
+                input_ids,
+                max_length=max_length,
+                decode_strategy="contrastive_search",
+                num_return_sequences=num_return_sequences,
+                attention_mask=attention_mask,
+                penalty_alpha=0.6,
+                top_k=1,
+                **process_kwargs,
+            )
+
+        kwargs = {}
+        if self.is_encoder_decoder:
+            encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
+                model,
+                input_ids,
+                attention_mask,
+                num_interleave=num_return_sequences,
+            )
+            kwargs["encoder_output"] = encoder_outputs
+            input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, axis=0)
+            attention_mask_clone = attention_mask_clone.repeat_interleave(num_return_sequences, axis=0)
+        else:
+            attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, axis=0)
+            input_ids_clone = input_ids.repeat_interleave(num_return_sequences, axis=0)
+
+        with paddle.no_grad():
+            output_sample = model.contrastive_search(
+                input_ids_clone,
+                attention_mask=attention_mask_clone,
+                max_length=max_length + 1 if self.is_encoder_decoder else max_length + input_ids.shape[-1],
+                logits_processors=logits_processors,
+                pad_token_id=getattr(model, model.base_model_prefix).config["pad_token_id"],
+                eos_token_id=getattr(model, model.base_model_prefix).config["eos_token_id"],
+                penalty_alpha=0.6,
+                top_k=1,
+                **process_kwargs,
+                **kwargs,
+            )
+        return output_sample, output_generate
+
     def _sample_generate(
         self,
         model,
@@ -424,6 +477,43 @@ def test_greedy_generate(self):
 
         self.assertListEqual(output_greedy[0].tolist(), output_generate[0].tolist())
 
+    def test_contrastive_generate(self):
+
+        for model_class in self.all_generative_model_classes.keys():
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
+            paddle.seed(124)
+            model = self._make_model_instance(config, model_class)
+            model.eval()
+
+            if self.is_encoder_decoder:
+                max_length = 4
+
+            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+                getattr(model, model.base_model_prefix).config["eos_token_id"],
+                forced_bos_token_id=getattr(
+                    getattr(model, model.base_model_prefix).config, "forced_bos_token_id", None
+                ),
+                forced_eos_token_id=getattr(
+                    getattr(model, model.base_model_prefix).config, "forced_eos_token_id", None
+                ),
+                max_length=max_length,
+                plus_length=1 if self.is_encoder_decoder else input_ids.shape[-1],
+            )
+            logits_warper = self._get_warper_and_kwargs()
+
+            # check that `generate()` and `contrastive_search()` produce the same result
+            output_sample, output_generate = self._contrastive_generate(
+                model=model,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_length=max_length,
+                num_return_sequences=1,
+                logits_processors=logits_processor,
+                logits_warper=logits_warper,
+                process_kwargs=process_kwargs,
+            )
+            self.assertListEqual(output_sample[0].tolist(), output_generate[0].tolist())
+
     def test_sample_generate(self):
         for model_class in self.all_generative_model_classes.keys():
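
For reviewers, a minimal end-to-end usage sketch of the new decode strategy follows (not part of the diff). The model and checkpoint names are illustrative; any decoder-only PaddleNLP model whose forward pass supports use_cache, return_dict=True and output_hidden_states=True should work, since contrastive_search relies on the cache and on the last hidden states at every step.

    import paddle
    from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer

    # illustrative checkpoint; substitute any causal LM that exposes a cache
    tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
    model = GPTLMHeadModel.from_pretrained("gpt2-en")
    model.eval()

    input_ids = paddle.to_tensor([tokenizer("DeepMind Company is")["input_ids"]])

    with paddle.no_grad():
        # penalty_alpha trades model confidence against the degeneration penalty;
        # top_k is the number of candidate tokens re-ranked at each step.
        ids, scores = model.generate(
            input_ids,
            max_length=64,
            decode_strategy="contrastive_search",
            penalty_alpha=0.6,
            top_k=4,
        )

    print(tokenizer.convert_ids_to_tokens(ids[0].tolist()))

With penalty_alpha=0.0 the strategy reduces to greedy search over the top-k candidates, which is what the test above exploits by setting top_k=1.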