224 changes: 222 additions & 2 deletions paddlenlp/transformers/generation_utils.py
@@ -30,7 +30,7 @@

from paddlenlp.utils.log import logger

from .model_outputs import ModelOutput
from .model_outputs import CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput

__all__ = ["GenerationMixin"]

@@ -540,6 +540,7 @@ def generate(
temperature=1.0,
top_k=0,
top_p=1.0,
penalty_alpha=0.6,
repetition_penalty=1.0,
num_beams=1,
num_beam_groups=1,
@@ -730,9 +731,10 @@

assert decode_strategy in [
"greedy_search",
"contrastive_search",
"sampling",
"beam_search",
], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format(
], "`decode_strategy` must be one of 'greedy_search','contrastive_search', 'sampling' or 'beam_search' but received {}.".format(
decode_strategy
)

@@ -881,6 +883,24 @@ def generate(
return self.greedy_search(
input_ids, logits_processors, max_len, pad_token_id, eos_token_id, **model_kwargs
)
elif decode_strategy == "contrastive_search":
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing"
" contrastive search."
)
return self.contrastive_search(
input_ids,
logits_processors,
max_len,
pad_token_id,
eos_token_id,
penalty_alpha,
top_k,
top_p,
temperature,
**model_kwargs,
)

elif decode_strategy == "sampling":
if num_return_sequences > 1:
@@ -1030,6 +1050,183 @@ def greedy_search(self, input_ids, logits_processors, max_length, pad_token_id,

return input_ids[:, origin_len:], scores

def contrastive_search(
self,
input_ids,
logits_processors,
max_length,
pad_token_id,
eos_token_id,
penalty_alpha,
top_k=None,
top_p=None,
temperature=None,
min_tokens_to_keep=1,
**model_kwargs
):

logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList()
batch_size, cur_len = input_ids.shape
origin_len = cur_len
unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool")
scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype())

while cur_len < max_length:
# on the first step of the loop, encode the whole prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; then (4) update model kwargs for the next step
if model_kwargs.get("cache") is None:
# prepare inputs
model_kwargs["use_cache"] = True
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and save
# the `encoder_outputs`
outputs = self(**model_inputs, return_dict=True, output_hidden_states=True)

# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
# previous tokens)
if self.is_encoder_decoder:
last_hidden_states = outputs.decoder_hidden_states[-1]
else:
last_hidden_states = outputs.hidden_states[-1]

# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = outputs.logits[:, -1, :]

model_kwargs = self.update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
)  # internally extracts the cache from the model output (extract_past_from_model_output)

# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self.expand_inputs_for_generation(
input_ids, expand_size=top_k, **model_kwargs
)

past_key_values = model_kwargs.get("cache") # cache 用的对不对
if past_key_values is None:
raise ValueError(
f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
"for contrastive search."
)
elif (
not isinstance(past_key_values[0], (tuple, paddle.Tensor))
or past_key_values[0][0].shape[0] != batch_size
):
raise ValueError(
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
"used for contrastive search without further modifications."
)

# contrastive search main logic: decoding consists of two steps:
# (1) candidate token recall; (2) candidate re-ranking by the degeneration penalty

# pre-process distribution
logit_for_next_step = self.adjust_logits_during_generation(logit_for_next_step)
logit_for_next_step = logits_processors(input_ids, logit_for_next_step)

origin_probs = F.softmax(logit_for_next_step)
origin_probs = paddle.log(origin_probs)  # which score to use?

if temperature is not None and temperature != 1.0:
logit_for_next_step = logit_for_next_step / temperature

next_probs = F.softmax(logit_for_next_step)
if top_k is not None and top_k != 0:
next_probs = TopKProcess(next_probs, top_k, min_tokens_to_keep)
if top_p is not None and top_p < 1.0:
next_probs = TopPProcess(next_probs, top_p, min_tokens_to_keep)
top_k_probs, top_k_ids = paddle.topk(next_probs, k=top_k)

# Replicates the new past_key_values to match the `top_k` candidates
new_key_values = []
for layer in model_kwargs["cache"]:
items = []
# item is either the key or the value matrix
for item in layer:
items.append(item.repeat_interleave(top_k, axis=0))
new_key_values.append(items)
model_kwargs["cache"] = new_key_values

# run the language model on the candidate tokens and collect their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.reshape([-1, 1]), **model_kwargs)
outputs = self(**next_model_inputs, return_dict=True, output_hidden_states=True)
next_past_key_values = outputs.past_key_values

logits = outputs.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if self.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
context_hidden = last_hidden_states.repeat_interleave(top_k, axis=0)

# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)

# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
# (model confidence minus degeneration penalty); (6) decoder hidden_states
next_tokens = paddle.index_sample(top_k_ids, selected_idx.unsqueeze(-1))

next_scores = paddle.index_sample(origin_probs, next_tokens)  # which score to use?

next_hidden = paddle.index_select(next_hidden, selected_idx)
last_hidden_states = paddle.concat([last_hidden_states, next_hidden], axis=1)

next_decoder_hidden_states = ()
for layer in full_hidden_states:
layer = paddle.index_select(layer, selected_idx)
next_decoder_hidden_states += (layer,)

# select the past_key_value
new_key_values = ()
for layer in next_past_key_values:
items = ()
# item is either the key or the value matrix
for item in layer:
item = paddle.index_select(item, selected_idx) # [B, K, num_head, seq_len, esz]
# item = item[:, selected_idx, ...] # [B, num_head, seq_len, esz]
items += (item,)
new_key_values += (items,)
next_past_key_values = new_key_values

logit_for_next_step = paddle.index_select(logits, selected_idx)

# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
if self.is_encoder_decoder:
outputs = Seq2SeqLMOutput(
past_key_values=next_past_key_values,
decoder_hidden_states=next_decoder_hidden_states,
)
else:
outputs = CausalLMOutputWithPast(
past_key_values=next_past_key_values,
hidden_states=next_decoder_hidden_states
)

if eos_token_id is not None:
next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id))

scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag)

cur_len += 1
input_ids = paddle.concat([input_ids, next_tokens], axis=1)

if eos_token_id is not None:
unfinished_flag = paddle.logical_and(unfinished_flag, next_tokens != eos_token_id)

# Stop when there is a </s> in all sentences
if not paddle.any(unfinished_flag):
break
model_kwargs = self.update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder
)
return input_ids[:, origin_len:], scores

def sample(
self,
input_ids,
@@ -1738,3 +1935,26 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
condition = paddle.cast(condition, "bool").reshape(probs.shape)
probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs)
return probs


def _ranking_fast(
context_hidden,
next_hidden,
next_top_k_probs,
alpha: float,
beam_width: int,
):
"""
Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
row in the batch.
"""
norm_context_hidden = context_hidden / context_hidden.norm(axis=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(axis=2, keepdim=True)
cosine_matrix = paddle.matmul(norm_context_hidden, norm_next_hidden.transpose([0, 2, 1])).squeeze(-1)
degeneration_penalty = paddle.max(cosine_matrix, axis=-1)
next_top_k_probs = next_top_k_probs.reshape([-1])
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
contrastive_score = paddle.stack(paddle.split(contrastive_score, beam_width), axis=-1)
selected_idx = paddle.argmax(contrastive_score, axis=-1)
return selected_idx
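
For intuition, the quantity `_ranking_fast` maximises is the contrastive score from the SimCTG paper: (1 - alpha) * p(v | x) minus alpha times the highest cosine similarity between the candidate's hidden state and the hidden states of the previous tokens. Below is a minimal standalone sketch of that computation for a single batch row; the shapes and tensor names are chosen for illustration only and are not part of this PR.

import paddle
import paddle.nn.functional as F

K, S, H, alpha = 3, 4, 8, 0.6                       # candidates, prefix length, hidden size, penalty weight
context_hidden = paddle.randn([S, H])               # hidden states of the already-generated prefix tokens
candidate_hidden = paddle.randn([K, H])             # hidden state produced by each top-k candidate token
candidate_probs = F.softmax(paddle.randn([K]))      # model confidence p(v | x) for each candidate

# degeneration penalty: the highest cosine similarity between a candidate and any prefix token
norm_context = F.normalize(context_hidden, axis=-1)      # [S, H]
norm_candidate = F.normalize(candidate_hidden, axis=-1)  # [K, H]
cosine_matrix = paddle.matmul(norm_candidate, norm_context.transpose([1, 0]))  # [K, S]
degeneration_penalty = paddle.max(cosine_matrix, axis=-1)                      # [K]

# contrastive score and selected candidate, mirroring _ranking_fast above
contrastive_score = (1.0 - alpha) * candidate_probs - alpha * degeneration_penalty
selected_idx = int(paddle.argmax(contrastive_score))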
90 changes: 90 additions & 0 deletions tests/transformers/test_generation_utils.py
@@ -250,6 +250,59 @@ def _greedy_generate(
)
return output_greedy, output_generate

def _contrastive_generate(
self,
model,
input_ids,
attention_mask,
max_length,
num_return_sequences,
logits_processors,
logits_warper,
process_kwargs,
):
with paddle.no_grad():
output_generate = model.generate(
input_ids,
max_length=max_length,
decode_strategy="contrastive_search",
num_return_sequences=num_return_sequences,
attention_mask=attention_mask,
penalty_alpha=0.6,
top_k=1,
**process_kwargs,
)

kwargs = {}
if self.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=num_return_sequences,
)
kwargs["encoder_output"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, axis=0)
attention_mask_clone = attention_mask_clone.repeat_interleave(num_return_sequences, axis=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, axis=0)
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, axis=0)

with paddle.no_grad():
output_contrastive = model.contrastive_search(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length + 1 if self.is_encoder_decoder else max_length + input_ids.shape[-1],
logits_processors=logits_processors,
pad_token_id=getattr(model, model.base_model_prefix).config["pad_token_id"],
eos_token_id=getattr(model, model.base_model_prefix).config["eos_token_id"],
penalty_alpha=0.6,
top_k=1,
**process_kwargs,
**kwargs,
)
return output_contrastive, output_generate

def _sample_generate(
self,
model,
@@ -424,6 +477,43 @@ def test_greedy_generate(self):

self.assertListEqual(output_greedy[0].tolist(), output_generate[0].tolist())

def test_contrastive_generate(self):

for model_class in self.all_generative_model_classes.keys():
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
paddle.seed(124)
model = self._make_model_instance(config, model_class)
model.eval()

if self.is_encoder_decoder:
max_length = 4

process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
getattr(model, model.base_model_prefix).config["eos_token_id"],
forced_bos_token_id=getattr(
getattr(model, model.base_model_prefix).config, "forced_bos_token_id", None
),
forced_eos_token_id=getattr(
getattr(model, model.base_model_prefix).config, "forced_eos_token_id", None
),
max_length=max_length,
plus_length=1 if self.is_encoder_decoder else input_ids.shape[-1],
)
logits_warper = self._get_warper_and_kwargs()

# check that `generate()` and `contrastive_search()` produce the same output
output_contrastive, output_generate = self._contrastive_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
max_length=max_length,
num_return_sequences=1,
logits_processors=logits_processor,
logits_warper=logits_warper,
process_kwargs=process_kwargs,
)
self.assertListEqual(output_contrastive[0].tolist(), output_generate[0].tolist())

def test_sample_generate(self):

for model_class in self.all_generative_model_classes.keys():
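
Taken together, the new decoding path is reached through `generate`. A rough usage sketch follows; the checkpoint name and model/tokenizer classes are placeholders chosen for illustration and are not prescribed by this diff.

import paddle
from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer

# hypothetical checkpoint, used only to show the call signature
tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
model = GPTLMHeadModel.from_pretrained("gpt2-en")
model.eval()

input_ids = paddle.to_tensor([tokenizer("Contrastive search aims to")["input_ids"]])
with paddle.no_grad():
    output_ids, scores = model.generate(
        input_ids,
        decode_strategy="contrastive_search",  # the strategy added by this PR
        penalty_alpha=0.6,                     # weight of the degeneration penalty
        top_k=4,                               # number of candidate tokens recalled per step
        max_length=64,
    )
print(tokenizer.convert_ids_to_tokens(output_ids[0].tolist()))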