diff --git a/applications/text_classification/hierarchical/README.md b/applications/text_classification/hierarchical/README.md index 98ae356572ac..9c715f02e29c 100644 --- a/applications/text_classification/hierarchical/README.md +++ b/applications/text_classification/hierarchical/README.md @@ -280,11 +280,11 @@ python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 预测错误的样本保存在bad_case.txt文件中: ```text -Text Label Prediction -据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。 组织关系,组织关系##加盟,组织关系##裁员 组织关系,组织关系##解雇 -冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从 组织关系,组织关系##裁员 组织关系,组织关系##解雇 -6月7日报道,IBM将裁员超过1000人。IBM周四确认,将裁减一千多人。据知情人士称,此次裁员将影响到约1700名员工,约占IBM全球逾34万员工中的0.5%。IBM股价今年累计上涨16%,但该公司4月发布的财报显示,一季度营收下降5%,低于市场预期。 组织关系,组织关系##裁员 组织关系,组织关系##裁员,财经/交易 -有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。 组织关系,组织关系##退出,组织关系##裁员 组织关系,组织关系##裁员 +Text Label Prediction +据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。 组织关系,组织关系##加盟,组织关系##裁员 组织关系,组织关系##解雇 +冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从 组织关系,组织关系##裁员 组织关系,组织关系##解雇 +6月7日报道,IBM将裁员超过1000人。IBM周四确认,将裁减一千多人。据知情人士称,此次裁员将影响到约1700名员工,约占IBM全球逾34万员工中的0.5%。IBM股价今年累计上涨16%,但该公司4月发布的财报显示,一季度营收下降5%,低于市场预期。 组织关系,组织关系##裁员 组织关系,组织关系##裁员,财经/交易 +有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。 组织关系,组织关系##退出,组织关系##裁员 组织关系,组织关系##裁员 ... ``` @@ -301,9 +301,9 @@ text: 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已 predict label: 组织关系,组织关系##解雇 label: 组织关系,组织关系##加盟,组织关系##裁员 examples with positive influence -support1 text: 尼克斯官方今日宣布,他们已经裁掉了前锋扎克-欧文,后者昨日才与尼克斯签约。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.99357 -support2 text: 活塞官方今日宣布,他们已经签下了克雷格-斯沃德,并且裁掉了托德-威瑟斯。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98344 -support3 text: 孟菲斯灰熊今年宣布,球队已经签下后卫达斯蒂-汉纳斯(DustyHannahs,版头图)并裁掉马特-穆尼。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98219 +support1 text: 尼克斯官方今日宣布,他们已经裁掉了前锋扎克-欧文,后者昨日才与尼克斯签约。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.99357 +support2 text: 活塞官方今日宣布,他们已经签下了克雷格-斯沃德,并且裁掉了托德-威瑟斯。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98344 +support3 text: 孟菲斯灰熊今年宣布,球队已经签下后卫达斯蒂-汉纳斯(DustyHannahs,版头图)并裁掉马特-穆尼。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98219 ... ``` diff --git a/applications/text_classification/multi_label/README.md b/applications/text_classification/multi_label/README.md index 94896cc8cc14..532399886e7e 100644 --- a/applications/text_classification/multi_label/README.md +++ b/applications/text_classification/multi_label/README.md @@ -276,7 +276,7 @@ python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 预测错误的样本保存在bad_case.txt文件中: ```text -Text Label Prediction +Text Label Prediction 2014年,王X以其与肖X协议离婚时未分割该套楼房的首付款为由,起诉至法院,要求分得楼房的首付款15万元。 不动产分割,有夫妻共同财产 不动产分割 但原、被告对已建立起的夫妻感情不够珍惜,因琐事即发生吵闹并最终分居,对夫妻感情造成了严重的影响,现原、被告已分居六年有余,且经人民法院判决不准离婚后仍未和好,夫妻感情确已破裂,依法应准予原、被告离婚。 二次起诉离婚,准予离婚,婚后分居,法定离婚 婚后分居,准予离婚 婚后生有一女,取名彭某乙,已11岁,现已由被告从铁炉白族乡中心小学转入走马镇李桥小学读书。 婚后有子女 婚后有子女,限制行为能力子女抚养 @@ -295,9 +295,9 @@ text: 2015年2月23日,被告将原告赶出家门,原告居住于娘家待 predict label: 婚后分居 label: 不履行家庭义务,婚后分居 examples with positive influence -support1 text: 2014年中秋节原告回了娘家,原、被告分居至今。 label: 婚后分居 score: 0.99942 -support2 text: 原告于2013年8月13日离开被告家,分居至今。 label: 婚后分居 score: 0.99916 -support3 text: 2014年4月,被告外出务工,双方分居至今。 label: 婚后分居 score: 0.99902 +support1 text: 2014年中秋节原告回了娘家,原、被告分居至今。 label: 婚后分居 score: 0.99942 +support2 text: 原告于2013年8月13日离开被告家,分居至今。 label: 婚后分居 score: 0.99916 +support3 text: 2014年4月,被告外出务工,双方分居至今。 label: 婚后分居 score: 0.99902 ... 
``` diff --git a/examples/language_model/glm/utils.py b/examples/language_model/glm/utils.py index f344818586e6..e5ec28374d50 100644 --- a/examples/language_model/glm/utils.py +++ b/examples/language_model/glm/utils.py @@ -20,12 +20,12 @@ from paddle import Tensor from paddle.optimizer.lr import LambdaDecay -from paddlenlp.trainer import Trainer -from paddlenlp.transformers.generation_utils import ( +from paddlenlp.generation.logits_process import ( LogitsProcessorList, MinLengthLogitsProcessor, NoRepeatNGramLogitsProcessor, ) +from paddlenlp.trainer import Trainer class GLMTrainer(Trainer): diff --git a/examples/language_model/gpt-3/dygraph/modeling.py b/examples/language_model/gpt-3/dygraph/modeling.py index 35fee45b76fc..14e5e96b9541 100644 --- a/examples/language_model/gpt-3/dygraph/modeling.py +++ b/examples/language_model/gpt-3/dygraph/modeling.py @@ -33,9 +33,9 @@ from paddle.nn.layer.transformer import _convert_param_attr_to_list import paddlenlp +from paddlenlp.generation.logits_process import LogitsProcessorList from paddlenlp.trainer.argparser import strtobool from paddlenlp.transformers import PretrainedModel, register_base_model -from paddlenlp.transformers.generation_utils import LogitsProcessorList from paddlenlp.transformers.model_outputs import ( CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, diff --git a/examples/language_model/gpt-3/static/args.py b/examples/language_model/gpt-3/static/args.py index 0f336d89aa46..9e3a88380d10 100644 --- a/examples/language_model/gpt-3/static/args.py +++ b/examples/language_model/gpt-3/static/args.py @@ -15,6 +15,7 @@ import argparse import paddle + from paddlenlp.utils.log import logger diff --git a/examples/text_classification/ernie_doc/predict.py b/examples/text_classification/ernie_doc/predict.py index b51422fd229c..a948acda74b9 100644 --- a/examples/text_classification/ernie_doc/predict.py +++ b/examples/text_classification/ernie_doc/predict.py @@ -12,20 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
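The import changes above (glm/utils.py and gpt-3/modeling.py) move the logits processors from `paddlenlp.transformers.generation_utils` into the new `paddlenlp.generation.logits_process` module. A minimal sketch of the new import path and of assembling a processor chain; the token ids and vocabulary size are placeholder values:

```python
import paddle

from paddlenlp.generation.logits_process import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
)

# Placeholder prefix and next-token logits (vocab size 100, eos id 2).
input_ids = paddle.to_tensor([[1, 2, 3, 1, 2]])
logits = paddle.rand([1, 100])

processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=8, eos_token_id=2),
        NoRepeatNGramLogitsProcessor(ngram_size=3),
    ]
)

# EOS (id 2) is masked because only 5 tokens exist so far, and token 3 is banned
# because the trigram (1, 2, 3) already occurred in the prefix.
logits = processors(input_ids, logits)
```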
+import argparse import os -import paddle +from functools import partial + import numpy as np -from paddlenlp.utils.env import PPNLP_HOME -from paddlenlp.utils.log import logger -from paddlenlp.taskflow.utils import dygraph_mode_guard -from modeling import ErnieDocForSequenceClassification -from paddlenlp.transformers import ErnieDocTokenizer, ErnieDocBPETokenizer -from paddlenlp.datasets import load_dataset -from data import ClassifierIterator, ImdbTextPreprocessor, HYPTextPreprocessor, to_json_file +import paddle import paddle.nn as nn +from data import ( + ClassifierIterator, + HYPTextPreprocessor, + ImdbTextPreprocessor, + to_json_file, +) +from modeling import ErnieDocForSequenceClassification from train import init_memory -from functools import partial -import argparse + +from paddlenlp.datasets import load_dataset +from paddlenlp.taskflow.utils import dygraph_mode_guard +from paddlenlp.transformers import ErnieDocBPETokenizer, ErnieDocTokenizer +from paddlenlp.utils.env import PPNLP_HOME +from paddlenlp.utils.log import logger # yapf: disable parser = argparse.ArgumentParser() diff --git a/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py b/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py index a4859890760a..dfba058061fa 100644 --- a/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py +++ b/examples/text_summarization/prophetnet/evaluate/cnndm/bs_pyrouge.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function, unicode_literals, division +from __future__ import division, print_function, unicode_literals import codecs import os @@ -28,6 +28,7 @@ from ConfigParser import ConfigParser import logging + from pyrouge.utils import log from pyrouge.utils.file_utils import verify_dir @@ -640,6 +641,7 @@ def __get_config_path(self): if __name__ == "__main__": import argparse + from utils.argparsers import rouge_path_parser parser = argparse.ArgumentParser(parents=[rouge_path_parser]) diff --git a/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py b/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py index df4cbbecbcb2..ecaaf8089e05 100644 --- a/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py +++ b/examples/text_summarization/prophetnet/evaluate/gigaword/bs_pyrouge.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function, unicode_literals, division +from __future__ import division, print_function, unicode_literals import codecs import logging @@ -640,6 +640,7 @@ def __get_config_path(self): if __name__ == "__main__": import argparse + from utils.argparsers import rouge_path_parser parser = argparse.ArgumentParser(parents=[rouge_path_parser]) diff --git a/paddlenlp/__init__.py b/paddlenlp/__init__.py index e50ba917a036..b3df32fbf0ba 100644 --- a/paddlenlp/__init__.py +++ b/paddlenlp/__init__.py @@ -38,6 +38,7 @@ datasets, embeddings, experimental, + generation, layers, losses, metrics, diff --git a/paddlenlp/generation/__init__.py b/paddlenlp/generation/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/paddlenlp/generation/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlenlp/generation/logits_process.py b/paddlenlp/generation/logits_process.py new file mode 100644 index 000000000000..ca21d0056c2f --- /dev/null +++ b/paddlenlp/generation/logits_process.py @@ -0,0 +1,352 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from abc import ABC +from typing import List + +import paddle + + +class LogitsProcessorList(List): + def __call__(self, input_ids, logits, **kwargs): + for processor in self: + processor_args = inspect.signature(processor.__call__).parameters + if len(processor_args) > 2: + assert all( + arg in kwargs for arg in list(processor_args.keys())[2:] + ), f"The parameters don't match for {processor.__class__}" + logits = processor(input_ids, logits, **kwargs) + else: + logits = processor(input_ids, logits) + return logits + + +class LogitsProcessor(ABC): + """ + Abstract base class for all logit processors that can be applied during + generation. + """ + + def __call__(self, input_ids, logits): + raise NotImplementedError( + f"{self.__class__} is an abstract class. " "Only classes inheriting this class can be called." + ) + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + Enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (int): The minimum length of generation sequence. + eos_token_id (int): The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length, eos_token_id): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError("`min_length` should be a positive integer, but get {}".format(min_length)) + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError("`eos_token_id` should be a positive integer, but get {}".format(eos_token_id)) + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids, logits): + cur_len = input_ids.shape[-1] + if cur_len < self.min_length: + logits[:, self.eos_token_id] = -float("inf") + return logits + + +class RepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + Enforcing an exponential penalty on repeated sequences. + + Args: + repetition_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. See `this paper + `__ for more details. 
+ """ + + def __init__(self, penalty: float): + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = penalty + + def __call__(self, input_ids, logits): + score = paddle.index_sample(logits, input_ids) + score = paddle.where(score < 0, score * self.penalty, score / self.penalty) + input_ids = input_ids + paddle.arange(logits.shape[0]).unsqueeze(-1) * logits.shape[-1] + outputs = paddle.scatter(logits.flatten(), input_ids.flatten(), score.flatten()).reshape(logits.shape) + return outputs + + +def _get_ngrams(ngram_size, prev_input_ids, num_hypos): + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + return generated_ngrams + + +def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - ngram_size + ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist()) + return banned_ngrams.get(ngram_idx, []) + + +def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len): + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + + generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) + + banned_tokens = [ + _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) + for hypo_idx in range(num_hypos) + ] + return banned_tokens + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces no repetition of n-grams. See + [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). + Args: + ngram_size (`int`): + All ngrams of size `ngram_size` can only occur once. + """ + + def __init__(self, ngram_size): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids, scores): + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + +class HammingDiversityLogitsProcessor(LogitsProcessor): + """ + This `LogitsProcessor` enforces diverse beam search. Note that this logits + processor is only effective for `group_beam_search`. See + `this paper `__ for more details. + + Args: + diversity_rate (float): This value is subtracted from a beam's score if + it generates a token same as any beam from other group at a particular + time. + num_beams (int): Number of beams used for group beam search. + num_beam_groups (int): Number of groups to divide `num_beams` into in order + to ensure diversity among different groups of beams. 
+ """ + + def __init__(self, diversity_rate, num_beams, num_beam_groups): + if not isinstance(diversity_rate, float) or (not diversity_rate > 0.0): + raise ValueError("`diversity_rate` should be a float strictly larger than 0.") + self._diversity_rate = diversity_rate + if not isinstance(num_beams, int) or num_beams < 2: + raise ValueError("`num_beams` should be an integer strictly larger than 1.") + self._num_beams = num_beams + if not isinstance(num_beam_groups, int) or num_beam_groups < 2: + raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") + self._num_sub_beams = num_beams // num_beam_groups + + def __call__(self, input_ids, scores, current_tokens, beam_group_idx): + batch_size = current_tokens.shape[0] // self._num_beams + group_start_idx = beam_group_idx * self._num_sub_beams + group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) + group_size = group_end_idx - group_start_idx + vocab_size = scores.shape[-1] + + if group_start_idx == 0: + return scores + + for batch_idx in range(batch_size): + previous_group_tokens = current_tokens[ + batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx + ] + token_frequency = paddle.bincount(previous_group_tokens, minlength=vocab_size) + scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_rate * token_frequency + + return scores + + +class ForcedBOSTokenLogitsProcessor(LogitsProcessor): + """ + This `LogitsProcessor` enforces the first generated token to be the selected `forced_bos_token`. + + Args: + forced_bos_token_id (:obj:`int`): + The id of the token to to be generated as the first token. + """ + + def __init__(self, forced_bos_token_id): + self.forced_bos_token_id = forced_bos_token_id + + def __call__(self, input_ids, scores): + cur_len = input_ids.shape[-1] + if cur_len == 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i != self.forced_bos_token_id]] = -float("inf") + scores[:, self.forced_bos_token_id] = 0 + return scores + + +class ForcedEOSTokenLogitsProcessor(LogitsProcessor): + """ + This `LogitsProcessor` enforces the last generated token to be the selected `forced_eos_token`. + + Args: + max_length (int): The maximum length of the sequence to be generated. + forced_eos_token_id (int): The id of the token to to be generated as the last token. + """ + + def __init__(self, max_length, forced_eos_token_id): + self.max_length = max_length + self.forced_eos_token_id = forced_eos_token_id + + def __call__(self, input_ids, scores): + cur_len = input_ids.shape[-1] + if cur_len == self.max_length - 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i != self.forced_eos_token_id]] = -float("inf") + scores[:, self.forced_eos_token_id] = 0 + return scores + + +class LogitsWarper: + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + def __call__(self, input_ids, scores): + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class TemperatureLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] for temperature (exponential scaling output probability distribution). + + Args: + temperature (`float`): + The value used to module the logits distribution. 
+ """ + + def __init__(self, temperature: float): + if not isinstance(temperature, float) or not (temperature > 0): + raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + + self.temperature = temperature + + def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor): + scores = scores / self.temperature + return scores + + +class TopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, top_k: int, filter_value: float = -float("inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = max(top_k, min_tokens_to_keep) + self.filter_value = filter_value + + def __call__(self, input_ids, probs): + top_k = min(self.top_k, probs.shape[-1]) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + topk_probs, _ = paddle.topk(probs, k=top_k) + + # NOTE: probs need to be float32, otherwise, paddle.full_like will do truncation + probs = probs.astype("float32") + probs = paddle.where( + probs < topk_probs[:, -1:], paddle.full_like(probs, self.filter_value, dtype="float32"), probs + ) + return probs + + +class TopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + + Args: + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__(self, top_p: float, filter_value: float = -float("inf"), min_tokens_to_keep: int = 1): + top_p = float(top_p) + if top_p < 0 or top_p > 1.0: + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids, probs): + sorted_logits = paddle.sort(probs, descending=False) + sorted_indices = paddle.argsort(probs, descending=False) + cumulative_probs = paddle.nn.functional.softmax(sorted_logits, axis=-1).cumsum(axis=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p) + + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 + + sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64") + sorted_indices = sorted_indices + paddle.arange(probs.shape[0]).unsqueeze(-1) * probs.shape[-1] + condition = paddle.scatter( + sorted_indices_to_remove.flatten(), sorted_indices.flatten(), sorted_indices_to_remove.flatten() + ) + condition = paddle.cast(condition, "bool").reshape(probs.shape) + + # NOTE: probs need to be float32, otherwise, paddle.full_like will do truncation + probs = probs.astype("float32") + probs = paddle.where(condition, paddle.full_like(probs, self.filter_value, dtype="float32"), probs) + return probs diff --git a/paddlenlp/generation/stopping_criteria.py b/paddlenlp/generation/stopping_criteria.py new file mode 100644 index 000000000000..32447b637914 --- /dev/null +++ b/paddlenlp/generation/stopping_criteria.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import warnings +from abc import ABC +from copy import deepcopy +from typing import Optional + +import paddle + + +class StoppingCriteria(ABC): + """ + Abstract base class for all stopping criteria that can be applied during + generation. + """ + + def __call__(self, input_ids: paddle.Tensor, logits: paddle.Tensor, **kwargs): + raise NotImplementedError(f"{self.__class__} is an abstract class. " "StoppingCriteria needs to be subclassed") + + +class MaxTimeCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the + time will start being counted when you initialize this function. You can override this by passing an + `initial_time`. + + Args: + max_time (`float`): + The maximum allowed time in seconds for the generation. + initial_time (`float`, *optional*, defaults to `time.time()`): + The start of the generation allowed time. 
+ """ + + def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): + self.max_time = max_time + self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp + + def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor, **kwargs) -> bool: + return time.time() - self.initial_timestamp > self.max_time + + +class MaxLengthCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep + in mind for decoder-only type of transformers, [this will include the initial prompted tokens]. + + Args: + max_length (`int`): + The maximum length that the output sequence can have in number of tokens. + """ + + def __init__(self, max_length: int): + self.max_length = max_length + + def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor, **kwargs) -> bool: + return input_ids.shape[-1] >= self.max_length + + +class StoppingCriteriaList(list): + def __call__(self, input_ids: paddle.Tensor, scores: paddle.Tensor, **kwargs): + return any(criteria(input_ids, scores) for criteria in self) + + @property + def max_length(self): + for stopping_criterium in self: + if isinstance(stopping_criterium, MaxLengthCriteria): + return stopping_criterium.max_length + return None + + +def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList: + stopping_max_length = stopping_criteria.max_length + new_stopping_criteria = deepcopy(stopping_criteria) + if stopping_max_length is not None and stopping_max_length != max_length: + warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) + elif stopping_max_length is None: + new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) + return new_stopping_criteria diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index 44fa88887d48..3b2c449b11f8 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -14,9 +14,7 @@ # limitations under the License. 
from __future__ import annotations -import inspect -from abc import ABC -from typing import List, Union +from typing import Optional, Union import paddle import paddle.nn as nn @@ -30,13 +28,32 @@ from paddle.fluid.layers.utils import map_structure from paddle.fluid.dygraph.base import in_declarative_mode +from paddle.nn import MultiHeadAttention from paddlenlp.utils.log import logger -from .model_outputs import ModelOutput +from .model_outputs import CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput __all__ = ["GenerationMixin"] +from paddlenlp.generation.logits_process import ( + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + LogitsProcessorList, + MinLengthLogitsProcessor, + NoRepeatNGramLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) +from paddlenlp.generation.stopping_criteria import ( + MaxLengthCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) + def get_unfinished_flag( input_ids: Tensor, unfinished_flag: Tensor, eos_token_id: Union[int, list[int], list[list[int]]] @@ -343,8 +360,8 @@ def prepare_seq_len_for_generation(input_ids, pad_token_id, eos_token_id): seq_len = paddle.full((input_ids.shape[0], 1), input_ids.shape[1], dtype="int64") return seq_len + @staticmethod def get_logits_processor( - self, min_length=None, max_length=None, eos_token_id=None, @@ -392,34 +409,70 @@ def get_logits_processor( return processors @staticmethod - def expand_inputs_for_generation(input_ids, expand_size, attention_mask=None, **model_kwargs): + def get_logits_warper( + temperature=None, + top_k=None, + top_p=None, + num_beams=1, + ): + # instantiate warpers list + warpers = LogitsProcessorList() + if temperature is not None and temperature != 1.0: + warpers.append(TemperatureLogitsWarper(temperature)) + + min_tokens_to_keep = 2 if num_beams > 1 else 1 + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep)) + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)) + # TODO + # Add more pre_processing for distribution during generation - index = paddle.tile(paddle.arange(paddle.shape(input_ids)[0]).unsqueeze(-1), [1, expand_size]).reshape([-1]) + return warpers - input_ids = paddle.gather(input_ids, index) + @staticmethod + def get_stopping_criteria(max_length: Optional[int], stopping_criteria: Optional[StoppingCriteriaList]): + criteria = StoppingCriteriaList() + if max_length is not None: + criteria.append(MaxLengthCriteria(max_length=max_length)) - if attention_mask is not None: - model_kwargs["attention_mask"] = paddle.gather(attention_mask, index) + if stopping_criteria is not None: + custom_criteria = StoppingCriteriaList() + custom_criteria_type = [type(cr) for cr in stopping_criteria] - if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = paddle.gather(token_type_ids, index) + for criterium in criteria: + if type(criterium) not in custom_criteria_type: + custom_criteria.append(criterium) + custom_criteria.extend(stopping_criteria) - if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: - position_ids = model_kwargs["position_ids"] - model_kwargs["position_ids"] = paddle.gather(position_ids, index) + return custom_criteria + else: + return criteria - if "seq_len" in 
model_kwargs and model_kwargs["seq_len"] is not None: - seq_len = model_kwargs["seq_len"] - model_kwargs["seq_len"] = paddle.gather(seq_len, index) + @staticmethod + def expand_inputs_for_generation( + input_ids: Optional[paddle.Tensor] = None, + expand_size: int = 1, + is_encoder_decoder: bool = False, + **model_kwargs, + ): + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" - if "encoder_output" in model_kwargs and model_kwargs["encoder_output"] is not None: - encoder_output = model_kwargs["encoder_output"] - model_kwargs["encoder_output"] = paddle.gather(encoder_output, index) + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], paddle.Tensor): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, axis=0) + return dict_to_expand - if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None: - role_ids = model_kwargs["role_ids"] - model_kwargs["role_ids"] = paddle.gather(role_ids, index) + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, axis=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) return input_ids, model_kwargs @@ -434,9 +487,11 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder # update cache if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): model_kwargs["cache"] = outputs[1] + model_kwargs["past_key_values"] = outputs[1] if isinstance(outputs, ModelOutput) and "past_key_values" in outputs: model_kwargs["cache"] = outputs.past_key_values + model_kwargs["past_key_values"] = outputs.past_key_values # update token_type_ids with last value if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None: @@ -540,7 +595,6 @@ def get_decoder_start_token_id(self, decoder_start_token_id=None, bos_token_id=N def prepare_inputs_for_generation(self, input_ids, **kwargs): # Implement in subclasses for custom behavior to prepare inputs in the # generate method. 
- return {"input_ids": input_ids} def adjust_logits_during_generation(self, logits): @@ -571,32 +625,33 @@ def _build_fast(self, kwargs): @paddle.no_grad() def generate( self, - input_ids=None, - attention_mask=None, - position_ids=None, - max_length=20, - min_length=0, - decode_strategy="greedy_search", - temperature=1.0, - top_k=0, - top_p=1.0, - repetition_penalty=1.0, - num_beams=1, - num_beam_groups=1, - length_penalty=0.0, - early_stopping=False, - bos_token_id=None, - eos_token_id=None, - pad_token_id=None, - decoder_start_token_id=None, - forced_bos_token_id=None, - forced_eos_token_id=None, - no_repeat_ngram_size=None, - num_return_sequences=1, - diversity_rate=0.0, - use_cache=True, - use_fast=False, - use_fp16_decoding=False, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + max_length: Optional[int] = 20, + min_length: Optional[int] = 0, + decode_strategy: Optional[str] = "greedy_search", + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + penalty_alpha: Optional[float] = None, + num_beams: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, + length_penalty: Optional[int] = None, + early_stopping: Optional[bool] = False, + bos_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + decoder_start_token_id: Optional[int] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, + no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = 1, + diversity_rate: Optional[float] = 0.0, + use_cache: Optional[bool] = True, + use_fast: Optional[bool] = False, + use_fp16_decoding: Optional[bool] = False, **model_kwargs ): r""" @@ -725,6 +780,51 @@ def generate( print(response) # 是的 + .. code-block:: + + # Generate the sequence by using "contrastive_search" strategy + ids, scores = model.generate( + input_ids=inputs['input_ids'], + token_type_ids=inputs['token_type_ids'], + position_ids=inputs['position_ids'], + attention_mask=inputs['attention_mask'], + decode_strategy="contrastive_search", + top_k=5, + penalty_alpha=0.85, + max_length=30) + print(ids.shape, scores.shape) + # [1, 2] [1, 1] + sequence_ids = ids.numpy().tolist()[0] + sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)] + response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False) + print(response) + # 谢谢 + + .. code-block:: + # Directly use model.contrastive_search() + + from paddlenlp.generation.stopping_criteria import StoppingCriteriaList, MaxLengthCriteria + [Note:] If you directly use model.contrastive_search, max_length = max_length + input_ids.shape[-1] + stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=30)]) + + # Generate the sequence by using "greedy_search" strategy + ids, scores = model.contrastive_search( + input_ids=inputs['input_ids'], + token_type_ids=inputs['token_type_ids'], + position_ids=inputs['position_ids'], + attention_mask=inputs['attention_mask'], + penalty_alpha=0.85, + top_k=5, + stopping_criteria=stopping_criteria + ) + print(ids.shape, scores.shape) + # [1, 2] [1, 1] + sequence_ids = ids.numpy().tolist()[0] + sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)] + response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False) + print(response) + # 谢谢 + .. 
code-block:: # Generate 2 sequences by using "sampling" strategy (top_k=5) @@ -771,6 +871,7 @@ def generate( "greedy_search", "sampling", "beam_search", + "contrastive_search", ], "`decode_strategy` must be one of 'greedy_search', 'sampling' or 'beam_search' but received {}.".format( decode_strategy ) @@ -881,6 +982,7 @@ def generate( model_kwargs["use_cache"] = use_cache + # max_len includes prefix length if is_tracing: min_len = input_ids.shape[-1] max_len = input_ids.shape[-1] @@ -891,6 +993,7 @@ def generate( min_len = input_len + min_length max_len = input_len + max_length + # prepare distribution pre_processing samplers logits_processors = self.get_logits_processor( min_length=min_len if min_length > 0 else None, max_length=max_len, @@ -910,15 +1013,52 @@ def generate( if "logits_processors" in model_kwargs: model_kwargs.pop("logits_processors") + # prepare stopping criteria + stopping_criteria = self.get_stopping_criteria( + max_length=max_len, + stopping_criteria=model_kwargs["stopping_criteria"] + if "stopping_criteria" in model_kwargs + and isinstance(model_kwargs["stopping_criteria"], StoppingCriteriaList) + else None, + ) + if "stopping_criteria" in model_kwargs: + model_kwargs.pop("stopping_criteria") + + # go into different generation modes if decode_strategy == "greedy_search": if num_return_sequences > 1: raise ValueError( "`num_return_sequences` has to be 1, but is {} " "when doing greedy search.".format(num_return_sequences) ) - + logits_warper = self.get_logits_warper(temperature=temperature) return self.greedy_search( - input_ids, logits_processors, max_len, pad_token_id, eos_token_id, **model_kwargs + input_ids=input_ids, + logits_processors=logits_processors, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, + ) + + elif decode_strategy == "contrastive_search": + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing" + " contrastive search." 
+ ) + logits_warper = self.get_logits_warper(temperature=temperature) + return self.contrastive_search( + input_ids=input_ids, + top_k=top_k, + penalty_alpha=penalty_alpha, + logits_processors=logits_processors, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, ) elif decode_strategy == "sampling": @@ -926,29 +1066,27 @@ def generate( input_ids, model_kwargs = self.expand_inputs_for_generation( input_ids, expand_size=num_return_sequences, **model_kwargs ) + # top_k and top_p can combine to use + logits_warper = self.get_logits_warper(temperature=temperature, top_k=top_k, top_p=top_p) if is_tracing: return self.sample_d2s( - input_ids, - logits_processors, - max_len, - pad_token_id, - eos_token_id, - top_k, - top_p, - temperature, + input_ids=input_ids, + logits_processors=logits_processors, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, **model_kwargs, ) else: return self.sample( - input_ids, - logits_processors, - max_len, - pad_token_id, - eos_token_id, - top_k, - top_p, - temperature, + input_ids=input_ids, + logits_processors=logits_processors, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, **model_kwargs, ) @@ -1016,16 +1154,29 @@ def generate( **model_kwargs, ) - def greedy_search(self, input_ids, logits_processors, max_length, pad_token_id, eos_token_id, **model_kwargs): - model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) + def greedy_search( + self, + input_ids: paddle.Tensor, + logits_processors: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if max_length is not None: + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) batch_size, cur_len = input_ids.shape origin_len = cur_len unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool") scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype()) - while cur_len < max_length: - + while True: # prepare model inputs & get model output model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -1044,11 +1195,19 @@ def greedy_search(self, input_ids, logits_processors, max_length, pad_token_id, # pre-process distribution next_token_logits = self.adjust_logits_during_generation(next_token_logits) next_tokens_scores = logits_processors(input_ids, next_token_logits) + + # record score before warper + origin_probs = F.softmax(next_tokens_scores) + origin_probs = paddle.log(origin_probs) + + next_tokens_scores = logits_warper(input_ids, next_tokens_scores) + # greedy probs = F.softmax(next_tokens_scores) probs = paddle.log(probs) + next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1) - next_scores = paddle.index_sample(probs.astype("float32"), next_tokens) + next_scores = paddle.index_sample(origin_probs.astype("float32"), next_tokens) if eos_token_id is not None: next_tokens = 
paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)) @@ -1061,41 +1220,324 @@ def greedy_search(self, input_ids, logits_processors, max_length, pad_token_id, if eos_token_id is not None: unfinished_flag = get_unfinished_flag(input_ids, unfinished_flag, eos_token_id) + unfinished_flag = get_unfinished_flag(input_ids, unfinished_flag, eos_token_id) # Stop when there is a in all sentences - if not paddle.any(unfinished_flag): + if not paddle.any(unfinished_flag) or stopping_criteria(input_ids, scores): break model_kwargs = self.update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) return input_ids[:, origin_len:], scores - def sample( + def contrastive_search( self, - input_ids, - logits_processors, - max_length, - pad_token_id, - eos_token_id, - top_k=None, - top_p=None, - temperature=None, - min_tokens_to_keep=1, + input_ids: paddle.Tensor, + top_k: Optional[int] = 1, + penalty_alpha: Optional[float] = 0, + logits_processors: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, **model_kwargs ): model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) + model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if max_length is not None: + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) batch_size, cur_len = input_ids.shape origin_len = cur_len unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool") scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype()) - while cur_len < max_length: - # prepare model inputs & get model output + # If the first step in the loop, encode all the prefix and obtain (1) past_key_values + # (2) last_hidden_states (3) logit_for_next_step (4) update model kwargs for the next step + while True: + if model_kwargs.get("cache") is None: + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + outputs = self(**model_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) + + # last decoder hidden states will be used to compute the degeneration penalty (cosine simlarity with + # pervious tokens) + if self.config.is_encoder_decoder: + last_hidden_states = outputs.decoder_hidden_states[-1] + else: + last_hidden_states = outputs.hidden_states[-1] + + # next logit for contrastive search to select top-k candidate tokens + if isinstance(outputs, tuple): + logits = outputs[0] + elif isinstance(outputs, ModelOutput): + logits = outputs.logits + else: + logits = outputs + + logit_for_next_step = logits[:, -1, :] + + model_kwargs = self.update_model_kwargs_for_generation( + outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # exapnds model inputs top_k times, for batched forward passes (akin to beam search) + _, model_kwargs = self.expand_inputs_for_generation( + input_ids=input_ids, expand_size=top_k, **model_kwargs + ) + + 
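Later in this loop the candidate tokens recalled above are re-ranked by `_ranking_fast`, whose body is not part of this hunk. As background, a minimal sketch of the standard contrastive-search scoring rule, model confidence minus a degeneration penalty; the helper name is hypothetical and it assumes the `[batch * k]` rows are laid out row-major so they align with `top_k_probs`:

```python
import paddle


def rank_candidates_sketch(context_hidden, next_hidden, top_k_probs, penalty_alpha):
    # context_hidden: [batch * k, ctx_len, hidden]  prefix hidden states, tiled k times
    # next_hidden:    [batch * k, 1, hidden]        hidden state of each candidate token
    # top_k_probs:    [batch, k]                    model confidence of each candidate
    norm_ctx = context_hidden / paddle.linalg.norm(context_hidden, p=2, axis=-1, keepdim=True)
    norm_next = next_hidden / paddle.linalg.norm(next_hidden, p=2, axis=-1, keepdim=True)

    # cosine similarity of every candidate against every prefix position
    cosine = paddle.matmul(norm_ctx, norm_next, transpose_y=True).squeeze(-1)  # [batch * k, ctx_len]
    degeneration_penalty = cosine.max(axis=-1).reshape(top_k_probs.shape)      # [batch, k]

    score = (1.0 - penalty_alpha) * top_k_probs - penalty_alpha * degeneration_penalty
    return score.argmax(axis=-1)  # index of the selected candidate for each batch element
```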
past_key_values = model_kwargs.get("cache") + + if past_key_values is None: + raise ValueError( + f"{self.__class__.__name__} does not support caching and therefore **can't** be used " + "for contrastive search." + ) + elif not isinstance(past_key_values[0], (tuple, paddle.Tensor)): + if isinstance(past_key_values[0][0], tuple) and past_key_values[0][0][0].shape[0] != batch_size: + # encoder-decoder arch like bart, cache is tuple (incremental_cache, static_cache) + raise ValueError( + f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " + "used for contrastive search without further modifications." + ) + elif not isinstance(past_key_values[0][0], tuple) and past_key_values[0][0].shape[0] != batch_size: + raise ValueError( + f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " + "used for contrastive search without further modifications." + ) + # contrastive search main logic starts: + # contrastive search decoding consists of two steps: (1) candidate tokens recall (2) candidate rerank by degeneration penalty + logit_for_next_step = self.adjust_logits_during_generation(logit_for_next_step) + logit_for_next_step = logits_processors(input_ids, logit_for_next_step) + + origin_probs = F.softmax(logit_for_next_step) + origin_probs = paddle.log(origin_probs) + + logit_for_next_step = logits_warper(input_ids, logit_for_next_step) + + next_probs = F.softmax(logit_for_next_step, axis=-1) + + top_k_probs, top_k_ids = paddle.topk(next_probs, k=top_k, axis=-1) + + new_key_values = () + if self.config.is_encoder_decoder: + if isinstance(model_kwargs["cache"][0][0], tuple): + for layer in model_kwargs["cache"]: + even = 0 + items = () + for item in layer: + if even % 2 == 0: + k = item[0].repeat_interleave(top_k, axis=0) + v = item[1].repeat_interleave(top_k, axis=0) + if even and even % 2: + static_k = item[0].repeat_interleave(top_k, axis=0) + static_v = item[1].repeat_interleave(top_k, axis=0) + + items += ( + ( + MultiHeadAttention.Cache(k, v), + MultiHeadAttention.StaticCache(static_k, static_v), + ), + ) + + even += 1 + new_key_values += items + else: + for layer in model_kwargs["cache"]: + items = () + for item in layer: + items += (item.repeat_interleave(top_k, axis=0),) + new_key_values += (items,) + model_kwargs["cache"] = new_key_values + else: + new_key_values = () + for layer in model_kwargs["cache"]: + even = 0 + items = () + for item in layer: + if even % 2 == 0: + k = item.repeat_interleave(top_k, axis=0) + if even and even % 2: + v = item.repeat_interleave(top_k, axis=0) + + items += (MultiHeadAttention.Cache(k, v),) + even += 1 + new_key_values += items + model_kwargs["cache"] = new_key_values + + model_kwargs["past_key_values"] = model_kwargs["cache"] + # compute the candidate tokens by the langugae model and collects their hidden states + tmp = paddle.reshape(top_k_ids, (-1, 1)) + + next_model_inputs = self.prepare_inputs_for_generation(tmp, **model_kwargs) + + outputs = self(**next_model_inputs, return_dict=True, output_hidden_states=True) + + if isinstance(outputs, tuple): + logits = outputs[0] + next_past_key_values = outputs[1] + elif isinstance(outputs, ModelOutput): + logits = outputs.logits + next_past_key_values = outputs.past_key_values + else: + logits = outputs + + # next_past_key_values = outputs.past_key_values + logits = logits[:, -1, :] + + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = 
outputs.decoder_hidden_states[-1][:, -1, :].unsqueeze(1) + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1][:, -1, :].unsqueeze(1) + full_hidden_states = outputs.hidden_states + + context_hidden = last_hidden_states.repeat_interleave(top_k, axis=0) + + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the + # model confidence + selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + + # prepare for the next step: (1) next_token_id (2) past_key_values (3) last_hidden_states for computing + # the degeneration penalty (4) logits for selecting next top-k candidates (5) selected tokens socres + # (model confidence minus degeneration penalty) (6) decoder hidden state + + index = selected_idx.reshape((len(top_k_ids),) + (1,) * (len(top_k_ids.shape) - 1)) + next_tokens = paddle.take_along_axis(top_k_ids, index, axis=1) + + next_scores = paddle.index_sample(origin_probs, next_tokens) + + next_hidden = paddle.stack(paddle.split(next_hidden.squeeze(axis=1), top_k), axis=1) + + index = selected_idx.reshape((len(next_hidden),) + (1,) * (len(next_hidden.shape) - 1)) + next_hidden = paddle.take_along_axis(next_hidden, index, axis=1).squeeze(1) + + last_hidden_states = paddle.concat([last_hidden_states, next_hidden.unsqueeze(1)], axis=1) + + next_decoder_hidden_states = () + for layer in full_hidden_states: + tmp = paddle.stack(paddle.split(layer, top_k), axis=1) + index = selected_idx.reshape((len(tmp),) + (1,) * (len(tmp.shape) - 1)) + layer = paddle.take_along_axis(tmp, index, axis=1).squeeze(1) + next_decoder_hidden_states += (layer,) + + # select the past_key_value + if self.config.is_encoder_decoder and isinstance(next_past_key_values[0][0], tuple): + new_key_values = () + for layer in next_past_key_values: + items = () + even = 0 + # item is either the key or the value matrix + for item in layer: + if even % 2 == 0: + k = paddle.stack(paddle.split(item[0], top_k, axis=0), axis=1) + index = selected_idx.reshape((len(k),) + (1,) * (len(k.shape) - 1)) + k = paddle.take_along_axis(k, index, axis=1).squeeze(1) + v = paddle.stack(paddle.split(item[1], top_k, axis=0), axis=1) + index = selected_idx.reshape((len(v),) + (1,) * (len(v.shape) - 1)) + v = paddle.take_along_axis(v, index, axis=1).squeeze(1) + + if even and even % 2: + static_k = paddle.stack(paddle.split(item[0], top_k, axis=0), axis=1) + index = selected_idx.reshape((len(static_k),) + (1,) * (len(static_k.shape) - 1)) + static_k = paddle.take_along_axis(static_k, index, axis=1).squeeze(1) + static_v = paddle.stack(paddle.split(item[1], top_k, axis=0), axis=1) + index = selected_idx.reshape((len(static_v),) + (1,) * (len(static_v.shape) - 1)) + static_v = paddle.take_along_axis(static_v, index, axis=1).squeeze(1) + + items += ( + MultiHeadAttention.Cache(k, v), + MultiHeadAttention.StaticCache(static_k, static_v), + ) + + even += 1 + new_key_values += (items,) + next_past_key_values = new_key_values + else: + new_key_values = () + for layer in next_past_key_values: + items = () + # item is either the key or the value matrix + for item in layer: + item = paddle.stack(paddle.split(item, top_k, axis=0), axis=1) + index = selected_idx.reshape((len(item),) + (1,) * (len(item.shape) - 1)) + item = paddle.take_along_axis(item, index, axis=1).squeeze(1) + items += (item,) + new_key_values += (items,) + next_past_key_values = new_key_values + + tmp = paddle.stack(paddle.split(logits, top_k), axis=1) + index = 
selected_idx.reshape((len(tmp),) + (1,) * (len(tmp.shape) - 1)) + logit_for_next_step = paddle.take_along_axis(tmp, index, axis=1).squeeze(1) + + # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration. + if self.config.is_encoder_decoder: + outputs = Seq2SeqLMOutput( + past_key_values=next_past_key_values, + decoder_hidden_states=next_decoder_hidden_states, + ) + else: + outputs = CausalLMOutputWithPast( + past_key_values=next_past_key_values, + hidden_states=next_decoder_hidden_states, + ) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)) + + input_ids = paddle.concat([input_ids, next_tokens], axis=1) + model_kwargs = self.update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag) + cur_len += 1 + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_flag = paddle.logical_and(unfinished_flag, next_tokens != eos_token_id) + + # Stop when there is a in all sentences + if not paddle.any(unfinished_flag) or stopping_criteria(input_ids, scores): + break + + return input_ids[:, origin_len:], scores + + def sample( + self, + input_ids: paddle.Tensor, + logits_processors: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + + batch_size, cur_len = input_ids.shape + origin_len = cur_len + unfinished_flag = paddle.full([batch_size, 1], True, dtype="bool") + scores = paddle.full([batch_size, 1], 0.0, dtype=paddle.get_default_dtype()) + + while True: + # prepare model inputs and get model output model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self(**model_inputs) @@ -1106,67 +1548,70 @@ def sample( else: logits = outputs - # [batch_size, vocab_size] logits = logits[:, -1, :] # pre-process distribution logits = self.adjust_logits_during_generation(logits) logits = logits_processors(input_ids, logits) - # sample + # record score origin_probs = F.softmax(logits) origin_probs = paddle.log(origin_probs) - if temperature is not None and temperature != 1.0: - logits = logits / temperature - probs = F.softmax(logits) - if top_k is not None and top_k != 0: - probs = TopKProcess(probs, top_k, min_tokens_to_keep) - if top_p is not None and top_p < 1.0: - probs = TopPProcess(probs, top_p, min_tokens_to_keep) + # adjust distribution + logits = logits_warper(input_ids, logits) + + # sample + probs = F.softmax(logits) # multinomial not support fp16 and bf16 currently, issue: 
https://github.com/PaddlePaddle/Paddle/issues/51852 - if probs.dtype == paddle.bfloat16 and top_k == 1: + if paddle.get_default_dtype() not in ["float32", "float64"]: probs = probs.astype("float32") - next_tokens = paddle.unsqueeze(paddle.argmax(probs, axis=-1), -1) - else: - next_tokens = paddle.multinomial(probs) + next_tokens = paddle.multinomial(probs) next_scores = paddle.index_sample(origin_probs, next_tokens) + # finished sentences should have their next token be a padding token if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)) - scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag) + input_ids = paddle.concat([input_ids, next_tokens], axis=1) + model_kwargs = self.update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag) cur_len += 1 - input_ids = paddle.concat([input_ids, next_tokens], axis=1) + # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: unfinished_flag = paddle.logical_and(unfinished_flag, next_tokens != eos_token_id) # Stop when there is a in all sentences - if not paddle.any(unfinished_flag): + if not paddle.any(unfinished_flag) or stopping_criteria(input_ids, scores): break - model_kwargs = self.update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder - ) + return input_ids[:, origin_len:], scores def sample_d2s( self, - input_ids, - logits_processors, - max_length, - pad_token_id, - eos_token_id, - top_k=None, - top_p=None, - temperature=None, - min_tokens_to_keep=1, + input_ids: paddle.Tensor, + logits_processors: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, **model_kwargs ): logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if max_length is not None: + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) batch_size, cur_len = paddle.shape(input_ids) # used for compute on gpu, avoid memcpy D2H @@ -1210,15 +1655,8 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f origin_probs = F.softmax(logits) origin_probs = paddle.log(origin_probs) - if temperature is not None or temperature != 1.0: - logits = logits / temperature - + logits = logits_warper(input_ids, logits) probs = F.softmax(logits) - if top_k is not None and top_k != 0: - probs = TopKProcess(probs, top_k, min_tokens_to_keep) - if top_p is not None and top_p < 1.0: - probs = TopPProcess(probs, top_p, min_tokens_to_keep) - # multinomial not support fp16 and bf16 currently, issue: https://github.com/PaddlePaddle/Paddle/issues/51852 if paddle.get_default_dtype() not in ["float32", "float64"]: probs = probs.astype("float32") @@ -1237,7 +1675,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, 
scores, unfinished_f unfinished_flag = paddle.logical_and(unfinished_flag, next_tokens != eos_token_id) model_kwargs = self.update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) return input_ids, scores, unfinished_flag, model_kwargs @@ -1253,9 +1691,10 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f # make the shape of attention_mask = (-1, -1, -1, -1) in dy2static. model_kwargs["attention_mask"] = paddle.reshape(attn_mask, paddle.shape(attn_mask)) model_kwargs["cache"] = outputs[1] if isinstance(outputs, tuple) else None - max_length = paddle.full([1], max_length, dtype="int64") - while cur_len < max_length: + # max_length = paddle.full([1], max_length, dtype="int64") + + while True: input_ids, scores, unfinished_flag, model_kwargs = _post_process_( _forward_(**model_kwargs), input_ids, @@ -1268,7 +1707,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f paddle.increment(cur_len) paddle.increment(cur_len_gpu) - if not paddle.any(unfinished_flag): + if not paddle.any(unfinished_flag) or stopping_criteria(input_ids, scores): break return input_ids[:, origin_len:], scores @@ -1286,6 +1725,8 @@ def beam_search( ): model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) + model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) + logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList() batch_size = len(beam_scorer._beam_hyps) @@ -1532,267 +1973,23 @@ def group_beam_search( return pred_ids[:, origin_len:], scores -class LogitsProcessorList(List): - def __call__(self, input_ids, logits, **kwargs): - for processor in self: - processor_args = inspect.signature(processor.__call__).parameters - if len(processor_args) > 2: - assert all( - arg in kwargs for arg in list(processor_args.keys())[2:] - ), f"The parameters don't match for {processor.__class__}" - logits = processor(input_ids, logits, **kwargs) - else: - logits = processor(input_ids, logits) - return logits - - -class LogitsProcessor(ABC): - """ - Abstract base class for all logit processors that can be applied during - generation. - """ - - def __call__(self, input_ids, logits): - raise NotImplementedError( - f"{self.__class__} is an abstract class. " "Only classes inheriting this class can be called." - ) - - -class MinLengthLogitsProcessor(LogitsProcessor): - r""" - Enforcing a min-length by setting EOS probability to 0. - - Args: - min_length (int): The minimum length of generation sequence. - eos_token_id (int): The id of the `end-of-sequence` token. - """ - - def __init__(self, min_length, eos_token_id): - if not isinstance(min_length, int) or min_length < 0: - raise ValueError("`min_length` should be a positive integer, but get {}".format(min_length)) - - if not isinstance(eos_token_id, int) or eos_token_id < 0: - raise ValueError("`eos_token_id` should be a positive integer, but get {}".format(eos_token_id)) - - self.min_length = min_length - self.eos_token_id = eos_token_id - - def __call__(self, input_ids, logits): - cur_len = input_ids.shape[-1] - if cur_len < self.min_length: - logits[:, self.eos_token_id] = -float("inf") - return logits - - -class RepetitionPenaltyLogitsProcessor(LogitsProcessor): - r""" - Enforcing an exponential penalty on repeated sequences. - - Args: - repetition_penalty (float): - The parameter for repetition penalty. 1.0 means no penalty. 
See `this paper - `__ for more details. - """ - - def __init__(self, penalty: float): - if not isinstance(penalty, float) or not (penalty > 0): - raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") - - self.penalty = penalty - - def __call__(self, input_ids, logits): - score = paddle.index_sample(logits, input_ids) - score = paddle.where(score < 0, score * self.penalty, score / self.penalty) - input_ids = input_ids + paddle.arange(logits.shape[0]).unsqueeze(-1) * logits.shape[-1] - outputs = paddle.scatter(logits.flatten(), input_ids.flatten(), score.flatten()).reshape(logits.shape) - return outputs - - -def _get_ngrams(ngram_size, prev_input_ids, num_hypos): - generated_ngrams = [{} for _ in range(num_hypos)] - for idx in range(num_hypos): - gen_tokens = prev_input_ids[idx].tolist() - generated_ngram = generated_ngrams[idx] - for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): - prev_ngram_tuple = tuple(ngram[:-1]) - generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] - return generated_ngrams - - -def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): - # Before decoding the next token, prevent decoding of ngrams that have already appeared - start_idx = cur_len + 1 - ngram_size - ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist()) - return banned_ngrams.get(ngram_idx, []) - - -def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len): - """Copied from fairseq for no_repeat_ngram in beam_search""" - if cur_len + 1 < ngram_size: - # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet - return [[] for _ in range(num_hypos)] - - generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) - - banned_tokens = [ - _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) - for hypo_idx in range(num_hypos) - ] - return banned_tokens - - -class NoRepeatNGramLogitsProcessor(LogitsProcessor): - r""" - [`LogitsProcessor`] that enforces no repetition of n-grams. See - [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). - Args: - ngram_size (`int`): - All ngrams of size `ngram_size` can only occur once. - """ - - def __init__(self, ngram_size): - if not isinstance(ngram_size, int) or ngram_size <= 0: - raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") - self.ngram_size = ngram_size - - def __call__(self, input_ids, scores): - num_batch_hypotheses = scores.shape[0] - cur_len = input_ids.shape[-1] - banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) - - for i, banned_tokens in enumerate(banned_batch_tokens): - scores[i, banned_tokens] = -float("inf") - - return scores - - -class HammingDiversityLogitsProcessor(LogitsProcessor): - """ - This `LogitsProcessor` enforces diverse beam search. Note that this logits - processor is only effective for `group_beam_search`. See - `this paper `__ for more details. - - Args: - diversity_rate (float): This value is subtracted from a beam's score if - it generates a token same as any beam from other group at a particular - time. - num_beams (int): Number of beams used for group beam search. - num_beam_groups (int): Number of groups to divide `num_beams` into in order - to ensure diversity among different groups of beams. 
- """ - - def __init__(self, diversity_rate, num_beams, num_beam_groups): - if not isinstance(diversity_rate, float) or (not diversity_rate > 0.0): - raise ValueError("`diversity_rate` should be a float strictly larger than 0.") - self._diversity_rate = diversity_rate - if not isinstance(num_beams, int) or num_beams < 2: - raise ValueError("`num_beams` should be an integer strictly larger than 1.") - self._num_beams = num_beams - if not isinstance(num_beam_groups, int) or num_beam_groups < 2: - raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") - self._num_sub_beams = num_beams // num_beam_groups - - def __call__(self, input_ids, scores, current_tokens, beam_group_idx): - batch_size = current_tokens.shape[0] // self._num_beams - group_start_idx = beam_group_idx * self._num_sub_beams - group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) - group_size = group_end_idx - group_start_idx - vocab_size = scores.shape[-1] - - if group_start_idx == 0: - return scores - - for batch_idx in range(batch_size): - previous_group_tokens = current_tokens[ - batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx - ] - token_frequency = paddle.bincount(previous_group_tokens, minlength=vocab_size) - scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_rate * token_frequency - - return scores - - -class ForcedBOSTokenLogitsProcessor(LogitsProcessor): - """ - This `LogitsProcessor` enforces the first generated token to be the selected `forced_bos_token`. - - Args: - forced_bos_token_id (:obj:`int`): - The id of the token to be generated as the first token. +def _ranking_fast( + context_hidden: paddle.Tensor, + next_hidden: paddle.Tensor, + next_top_k_probs: paddle.Tensor, + alpha: float, + beam_width: int, +): """ - - def __init__(self, forced_bos_token_id): - self.forced_bos_token_id = forced_bos_token_id - - def __call__(self, input_ids, scores): - cur_len = input_ids.shape[-1] - if cur_len == 1: - num_tokens = scores.shape[1] - scores[:, [i for i in range(num_tokens) if i != self.forced_bos_token_id]] = -float("inf") - scores[:, self.forced_bos_token_id] = 0 - return scores - - -class ForcedEOSTokenLogitsProcessor(LogitsProcessor): + Rerank the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens). + Rerank the index of the best candidate for each row in the batch. """ - This `LogitsProcessor` enforces the last generated token to be the selected `forced_eos_token`. - - Args: - max_length (int): The maximum length of the sequence to be generated. - forced_eos_token_id (int): The id of the token to be generated as the last token. 
- """ - - def __init__(self, max_length, forced_eos_token_id): - self.max_length = max_length - self.forced_eos_token_id = forced_eos_token_id - - def __call__(self, input_ids, scores): - cur_len = input_ids.shape[-1] - if cur_len == self.max_length - 1: - num_tokens = scores.shape[1] - scores[ - :, [i for i in range(num_tokens) if i != self.forced_eos_token_id] - ] = -1e9 # TODO change back to -inf after paddle.topk is fixed - scores[:, self.forced_eos_token_id] = 0 - return scores - - -def TopKProcess(probs, top_k, min_tokens_to_keep): - top_k = min(max(top_k, min_tokens_to_keep), probs.shape[-1]) - # Remove all tokens with a probability less than the last token of the top-k - # cast to float16 to support generation & d2s - if probs.dtype == paddle.bfloat16: - probs = paddle.cast(probs, paddle.float32) - topk_probs, _ = paddle.topk(probs, k=top_k) - topk_probs = paddle.cast(topk_probs, paddle.bfloat16) - else: - topk_probs, _ = paddle.topk(probs, k=top_k) - - probs = paddle.where(probs >= topk_probs[:, -1:], probs, paddle.full_like(probs, 0.0)) - return probs - - -def TopPProcess(probs, top_p, min_tokens_to_keep): - sorted_probs = paddle.sort(probs, descending=True) - sorted_indices = paddle.argsort(probs, descending=True) - cumulative_probs = paddle.cumsum(sorted_probs, axis=-1) - - # Remove tokens with cumulative probs above the top_p, But keep at - # least min_tokens_to_keep tokens - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Set 'min_tokens_to_keep - 1' because the first token is kept - sorted_indices_to_remove[:, : min_tokens_to_keep - 1] = 0 - # Keep the first token - sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64") - sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() - sorted_indices_to_remove[:, 0] = 0 - - # Scatter sorted tensors to original indexing - sorted_indices = sorted_indices + paddle.arange(probs.shape[0]).unsqueeze(-1) * probs.shape[-1] - condition = paddle.scatter( - sorted_indices_to_remove.flatten(), sorted_indices.flatten(), sorted_indices_to_remove.flatten() - ) - condition = paddle.cast(condition, "bool").reshape(probs.shape) - probs = paddle.where(condition, paddle.full_like(probs, 0.0), probs) - return probs + norm_context_hidden = context_hidden / paddle.linalg.norm(context_hidden, axis=2, keepdim=True) + norm_next_hidden = next_hidden / paddle.linalg.norm(next_hidden, axis=2, keepdim=True) + consine_matrix = paddle.matmul(norm_context_hidden, norm_next_hidden.transpose([0, 2, 1])).squeeze(-1) + degeneration_penalty = paddle.max(consine_matrix, axis=-1) + next_top_k_probs = paddle.reshape(next_top_k_probs, (-1,)) + contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty + contrastive_score = paddle.stack(paddle.split(contrastive_score, beam_width), axis=-1) + selected_idx = paddle.argmax(contrastive_score, axis=-1) + return selected_idx diff --git a/pipelines/pipelines/document_stores/milvus2.py b/pipelines/pipelines/document_stores/milvus2.py index 73a9c29c502f..5103dc93a5f0 100644 --- a/pipelines/pipelines/document_stores/milvus2.py +++ b/pipelines/pipelines/document_stores/milvus2.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union - import logging import warnings -import numpy as np +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union +import numpy as np from tqdm import tqdm try: - from pymilvus import FieldSchema, CollectionSchema, Collection, connections, utility + from pymilvus import Collection, CollectionSchema, FieldSchema, connections, utility from pymilvus.client.abstract import QueryResult from pymilvus.client.types import DataType except (ImportError, ModuleNotFoundError) as ie: @@ -29,9 +28,9 @@ _optional_component_not_installed(__name__, "milvus2", ie) -from pipelines.schema import Document -from pipelines.document_stores.sql import SQLDocumentStore from pipelines.document_stores.base import get_batches_from_generator +from pipelines.document_stores.sql import SQLDocumentStore +from pipelines.schema import Document if TYPE_CHECKING: from pipelines.nodes.retriever.base import BaseRetriever diff --git a/tests/generation/__init__.py b/tests/generation/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/tests/generation/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py new file mode 100644 index 000000000000..2657a8462eca --- /dev/null +++ b/tests/generation/test_logits_process.py @@ -0,0 +1,323 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
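The new test package below covers the relocated generation utilities. For downstream callers the visible change is twofold: the processors and warpers move from `paddlenlp.transformers.generation_utils` (removed elsewhere in this diff) to `paddlenlp.generation.logits_process`, and `sample()` now receives a `logits_warper` list plus a `stopping_criteria` list instead of raw `top_k` / `top_p` / `temperature` arguments. A sketch of one decoding step composed from these objects, on dummy tensors rather than a real model; the call order paraphrases the rewritten `sample()` loop above.

```python
import paddle
import paddle.nn.functional as F

from paddlenlp.generation.logits_process import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from paddlenlp.generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

eos_token_id = 0
input_ids = paddle.to_tensor([[11, 24, 3]])  # partially generated sequence
logits = paddle.randn([1, 100])              # next-token logits from some model

# Hard constraints vs. distribution-shaping warpers, as split by the new sample() signature.
logits_processors = LogitsProcessorList([MinLengthLogitsProcessor(min_length=5, eos_token_id=eos_token_id)])
logits_warper = LogitsProcessorList(
    [TemperatureLogitsWarper(temperature=0.8), TopKLogitsWarper(top_k=50), TopPLogitsWarper(top_p=0.9)]
)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])

# One step: constrain, warp, sample, then ask the criteria whether to keep going.
logits = logits_processors(input_ids, logits)
logits = logits_warper(input_ids, logits)
next_token = paddle.multinomial(F.softmax(logits, axis=-1))
input_ids = paddle.concat([input_ids, next_token], axis=1)
done = stopping_criteria(input_ids, logits)  # False until max_length is reached
```

Keeping hard constraints in `logits_processors` and distribution-shaping in `logits_warper` matches how the rewritten `sample()` applies them: processors run before the per-token score is recorded, warpers only when drawing the next token.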
+ +import random +import unittest + +import paddle +from paddle import nn + + +def ids_tensor(shape, vocab_size, rng=None, name=None): + # Creates a random int32 tensor of the shape within the vocab size + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return paddle.to_tensor(data=values).reshape(shape) + + +from paddlenlp.generation.logits_process import ( + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + LogitsProcessorList, + MinLengthLogitsProcessor, + NoRepeatNGramLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + + +class LogitsProcessorTest(unittest.TestCase): + def _get_uniform_logits(self, batch_size: int, length: int): + scores = paddle.ones((batch_size, length)) / length + return scores + + def test_min_length_dist_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + + min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + + # check that min length is applied at length 5 + input_ids = ids_tensor((batch_size, 5), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")]) + + # check that min length is not applied anymore at length 15 + input_ids = ids_tensor((batch_size, 15), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores) + self.assertFalse(paddle.isinf(scores_before_min_length).any()) + + def test_temperature_dist_warper(self): + input_ids = None + length = 20 + + scores = self._get_uniform_logits(batch_size=2, length=length) + + # tweak scores to not be uniform anymore + scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch + scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch + + # compute softmax + probs = nn.functional.softmax(scores, axis=-1) + + temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5) + temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3) + + warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), axis=-1) + warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), axis=-1) + + # uniform distribution stays uniform + self.assertTrue(paddle.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)) + self.assertTrue(paddle.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)) + + # sharp peaks get higher, valleys get lower + self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max()) + self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min()) + + # smooth peaks get lower, valleys get higher + self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max()) + self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min()) + + def test_repetition_penalty_dist_process(self): + input_ids = paddle.to_tensor([[0, 1], [5, 0]]) + vocab_size = 10 + + scores = self._get_uniform_logits(batch_size=2, length=vocab_size) + + # give values special values + scores[0, 0] = -(1 / vocab_size) + scores[1, 5] = 4 / vocab_size + + rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) + + scores = rep_penalty_proc(input_ids, scores.clone()) 
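+        # RepetitionPenaltyLogitsProcessor divides positive scores by `penalty` and multiplies
+        # negative scores by it, but only at token ids already present in input_ids, so repeated
+        # tokens always become less likely; the asserts below check exactly that arithmetic.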
+ + # check that values were correctly changed + self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2) + self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2) + + self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2) + self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2) + + def test_top_k_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create ramp distribution + ramp_logits = paddle.arange(vocab_size).unsqueeze(0).tile((batch_size, 1)) + ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size + ramp_logits = ramp_logits.astype("float32") + top_k_warp = TopKLogitsWarper(3) + + scores = top_k_warp(input_ids, ramp_logits) + + # check that correct tokens are filtered + self.assertListEqual(paddle.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False]) + self.assertListEqual(paddle.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True]) + + # check special cases + length = 5 + + logits = self._get_uniform_logits(batch_size=batch_size, length=length) + top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3) + scores = top_k_warp_safety_check(input_ids, logits) + # uniform dist is not changed + self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [0, 0]) + + ramp_logits = paddle.arange(length).unsqueeze(0).tile((batch_size, 1)) + ramp_logits = ramp_logits.astype("float32") + scores = top_k_warp_safety_check(input_ids, ramp_logits) + + # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified + self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2]) + + def test_top_p_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) + dist = paddle.log(paddle.to_tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]])) + + top_p_warp = TopPLogitsWarper(0.8) + filtered_dist = paddle.exp(top_p_warp(input_ids, dist)) + + # dist should be filtered to keep min num values so that sum is >= top_p + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = paddle.to_tensor([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]]) + self.assertTrue(paddle.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + + # check edge cases with negative and extreme logits + ramp_logits = paddle.arange(vocab_size).unsqueeze(0).tile((batch_size, 1)) - (vocab_size // 2) + ramp_logits = ramp_logits.astype("float32") + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = top_p_warp(input_ids, ramp_logits) + + # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. 
+ self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2]) + + def test_no_repeat_ngram_dist_processor(self): + vocab_size = 3 + batch_size = 2 + + input_ids = paddle.to_tensor([[1, 1, 2, 1], [0, 1, 0, 1]]) + scores = self._get_uniform_logits(batch_size, vocab_size) + + no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2) + no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3) + + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone()) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone()) + + # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual( + paddle.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]] + ) + + # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual( + paddle.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]] + ) + + def test_processor_list(self): + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + eos_token_id = 0 + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + input_ids_comp = input_ids.clone() + + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_comp = scores.clone() + + # instantiate all dist processors + min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + temp_dist_warp = TemperatureLogitsWarper(temperature=0.5) + rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) + top_k_warp = TopKLogitsWarper(3) + top_p_warp = TopPLogitsWarper(0.8) + no_repeat_proc = NoRepeatNGramLogitsProcessor(2) + + # no processor list + scores = min_dist_proc(input_ids, scores) + scores = temp_dist_warp(input_ids, scores) + scores = rep_penalty_proc(input_ids, scores) + scores = top_k_warp(input_ids, scores) + scores = top_p_warp(input_ids, scores) + scores = no_repeat_proc(input_ids, scores) + + # with processor list + processor = LogitsProcessorList( + [ + min_dist_proc, + temp_dist_warp, + rep_penalty_proc, + top_k_warp, + top_p_warp, + no_repeat_proc, + ] + ) + scores_comp = processor(input_ids, scores_comp) + + # scores should be equal + self.assertTrue(paddle.allclose(scores, scores_comp, atol=1e-3)) + + # input_ids should never be changed + self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist()) + + def test_hamming_diversity(self): + vocab_size = 4 + num_beams = 2 + num_beam_groups = 2 + + scores = self._get_uniform_logits(num_beams, vocab_size) + # batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1 + # batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1 + current_tokens = paddle.to_tensor([0, 3, 1, 2]) + + diversity_logits_processor = HammingDiversityLogitsProcessor( + diversity_rate=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups + ) + + processed_scores = diversity_logits_processor(None, scores, current_tokens, 1) + + self.assertTrue( + paddle.allclose(processed_scores[0], paddle.to_tensor([-0.7500, 0.2500, 0.2500, 0.2500]), atol=1e-3) + ) + self.assertTrue( + paddle.allclose(processed_scores[1], paddle.to_tensor([0.2500, -0.7500, 0.2500, 0.2500]), atol=1e-3) + ) + + def test_forced_bos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + bos_token_id = 0 + + logits_processor = ForcedBOSTokenLogitsProcessor(forced_bos_token_id=bos_token_id) + + # check that all scores are -inf except the bos_token_id score + 
        input_ids = ids_tensor((batch_size, 1), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores)
+        self.assertTrue(paddle.isinf(-scores[:, bos_token_id + 1 :]).all())
+        self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0])  # score for bos_token_id should be zero
+
+        # check that bos_token_id is not forced if current length is greater than 1
+        input_ids = ids_tensor((batch_size, 4), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores)
+        self.assertFalse(paddle.isinf(scores).any())
+
+    def test_forced_eos_token_logits_processor(self):
+        vocab_size = 20
+        batch_size = 4
+        eos_token_id = 0
+        max_length = 5
+
+        logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, forced_eos_token_id=eos_token_id)
+
+        # check that all scores are -inf except the eos_token_id when max_length-1 is reached
+        input_ids = ids_tensor((batch_size, 4), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores)
+        self.assertTrue(paddle.isinf(-scores[:, eos_token_id + 1 :]).all())
+        self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0])  # score for eos_token_id should be zero
+
+        # check that eos_token_id is not forced if max_length-1 is not reached
+        input_ids = ids_tensor((batch_size, 3), vocab_size=20)
+        scores = self._get_uniform_logits(batch_size, vocab_size)
+        scores = logits_processor(input_ids, scores)
+        self.assertFalse(paddle.isinf(scores).any())
diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
new file mode 100644
index 000000000000..f557dbaba279
--- /dev/null
+++ b/tests/generation/test_stopping_criteria.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
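The stopping-criteria objects tested below are what the rewritten decoding loops consult on every step (`if not paddle.any(unfinished_flag) or stopping_criteria(input_ids, scores): break`). A short usage sketch, using only constructors exercised by these tests:

```python
import paddle

from paddlenlp.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

# Stop after 32 total positions or 2 seconds of wall-clock time, whichever fires first.
criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32), MaxTimeCriteria(max_time=2.0)])

# Reconcile with an explicit max_length argument: warns on a mismatch and appends a
# MaxLengthCriteria when the list does not contain one (see the tests below).
criteria = validate_stopping_criteria(criteria, 32)

input_ids = paddle.randint(0, 100, shape=[2, 8])
scores = paddle.zeros([2, 8])
print(criteria(input_ids, scores))  # False: neither criterion has fired yet
```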
+ +import random +import time +import unittest + +import paddle + +from paddlenlp.generation.stopping_criteria import ( + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteriaList, + validate_stopping_criteria, +) + + +def ids_tensor(shape, vocab_size, rng=None, name=None): + # Creates a random int32 tensor of the shape within the vocab size + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + return paddle.to_tensor(data=values).reshape(shape) + + +class StoppingCriteriaTestCase(unittest.TestCase): + def _get_tensors(self, length): + batch_size = 3 + vocab_size = 250 + + input_ids = ids_tensor((batch_size, length), vocab_size) + scores = paddle.ones((batch_size, length)) / length + return input_ids, scores + + def test_list_criteria(self): + input_ids, scores = self._get_tensors(5) + + criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=10), + MaxTimeCriteria(max_time=0.1), + ] + ) + self.assertFalse(criteria(input_ids, scores)) + + input_ids, scores = self._get_tensors(9) + self.assertFalse(criteria(input_ids, scores)) + + input_ids, scores = self._get_tensors(10) + self.assertTrue(criteria(input_ids, scores)) + + def test_max_length_criteria(self): + criteria = MaxLengthCriteria(max_length=10) + + input_ids, scores = self._get_tensors(5) + self.assertFalse(criteria(input_ids, scores)) + + input_ids, scores = self._get_tensors(9) + self.assertFalse(criteria(input_ids, scores)) + + input_ids, scores = self._get_tensors(10) + self.assertTrue(criteria(input_ids, scores)) + + def test_max_time_criteria(self): + input_ids, scores = self._get_tensors(5) + + criteria = MaxTimeCriteria(max_time=0.1) + self.assertFalse(criteria(input_ids, scores)) + + criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2) + self.assertTrue(criteria(input_ids, scores)) + + def test_validate_stopping_criteria(self): + validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10) + + with self.assertWarns(UserWarning): + validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11) + + stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11) + + self.assertEqual(len(stopping_criteria), 1) diff --git a/tests/transformers/test_generation_utils.py b/tests/transformers/test_generation_utils.py index 0855d03238c7..e1dfbac7e434 100644 --- a/tests/transformers/test_generation_utils.py +++ b/tests/transformers/test_generation_utils.py @@ -19,6 +19,20 @@ import numpy as np import paddle +from paddlenlp.generation.logits_process import ( + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + LogitsProcessorList, + MinLengthLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TopKLogitsWarper, + TopPLogitsWarper, +) +from paddlenlp.generation.stopping_criteria import ( + MaxLengthCriteria, + StoppingCriteriaList, +) from paddlenlp.transformers import ( # import gpt model AutoModelForCausalLM, AutoTokenizer, @@ -30,14 +44,6 @@ ) from paddlenlp.transformers.generation_utils import ( BeamSearchScorer, - ForcedBOSTokenLogitsProcessor, - ForcedEOSTokenLogitsProcessor, - HammingDiversityLogitsProcessor, - LogitsProcessorList, - MinLengthLogitsProcessor, - RepetitionPenaltyLogitsProcessor, - TopKProcess, - TopPProcess, get_unfinished_flag, ) from tests.testing_utils import slow @@ -47,13 +53,18 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, 
+ filter_value=-float("inf"), min_tokens_to_keep=1, ): if top_k > 0: - logits = TopKProcess(logits, top_k, min_tokens_to_keep) + logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) if 0 <= top_p <= 1.0: - logits = TopPProcess(logits, top_p, min_tokens_to_keep) + logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) return logits @@ -296,6 +307,8 @@ def _sample_generate( input_ids_clone = input_ids.repeat_interleave(num_return_sequences, axis=0) with paddle.no_grad(): + logits_warper = LogitsProcessorList() + logits_warper.append(TopKLogitsWarper(top_k=1, min_tokens_to_keep=1)) output_sample = model.sample( input_ids_clone, attention_mask=attention_mask_clone, @@ -303,7 +316,7 @@ def _sample_generate( logits_processors=logits_processors, pad_token_id=getattr(model, model.base_model_prefix).config["pad_token_id"], eos_token_id=getattr(model, model.base_model_prefix).config["eos_token_id"], - top_k=1, + logits_warper=logits_warper, **process_kwargs, **kwargs, ) @@ -418,6 +431,86 @@ def _group_beam_search_generate( ) return output_generate, output_group_beam_search + def _contrastive_generate( + self, + model, + input_ids, + attention_mask, + max_length, + ): + contrastive_search_kwargs = { + "penalty_alpha": 0.6, + "top_k": 5, + } + + if self.is_encoder_decoder: + max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + eos_token_id=getattr(model, model.base_model_prefix).config["eos_token_id"], + forced_bos_token_id=getattr(getattr(model, model.base_model_prefix).config, "forced_bos_token_id", None), + forced_eos_token_id=getattr(getattr(model, model.base_model_prefix).config, "forced_eos_token_id", None), + max_length=max_length, + plus_length=1 if self.is_encoder_decoder else input_ids.shape[-1], + ) + + kwargs = {} + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + + output_generate = model.generate( + input_ids, + max_length=max_length, + decode_strategy="contrastive_search", + **logits_process_kwargs, + **model_kwargs, + **contrastive_search_kwargs, + ) + + if self.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, + input_ids, + attention_mask, + ) + kwargs["encoder_output"] = encoder_outputs + + with paddle.no_grad(): + model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} + stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length + input_ids.shape[-1])]) + output_contrastive = model.contrastive_search( + input_ids=input_ids, + stopping_criteria=stopping_criteria, + logits_processors=logits_processor, + eos_token_id=getattr(model, model.base_model_prefix).config["eos_token_id"], + pad_token_id=getattr(model, model.base_model_prefix).config["pad_token_id"], + **kwargs, + **model_kwargs, + **contrastive_search_kwargs, + ) + + return output_contrastive, output_generate + + def test_contrastive_generate(self): + + # check `generate()` and `contrastive_search()` are equal + for model_class in self.all_generative_model_classes.keys(): + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + paddle.seed(124) + # NOTE: contrastive search only works with cache on at the moment. 
+ if not hasattr(config, "use_cache"): + return + config.use_cache = True + config.is_decoder = True + + model = self._make_model_instance(config, model_class) + model.eval() + + output_contrastive, output_generate = self._contrastive_generate( + model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length + ) + self.assertListEqual(output_contrastive[0].tolist(), output_generate[0].tolist()) + def test_greedy_generate(self): # check `generate()` and `greedy_search()` are equal for model_class in self.all_generative_model_classes.keys(): @@ -732,11 +825,11 @@ def test_top_k_top_p_filtering(self): ) output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) - non_inf_output = output[output >= -10000] - non_inf_idx = (output >= -10000).nonzero() + non_inf_output = output[output != -float("inf")] + non_inf_idx = (output != -float("inf")).nonzero() self.assertTrue(paddle.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) - self.assertTrue(paddle.all(paddle.eq(non_inf_expected_idx, non_inf_idx))) + self.assertTrue(paddle.all(paddle.equal(non_inf_expected_idx, non_inf_idx))) class GenerationIntegrationTests: @@ -765,6 +858,7 @@ def test_diverse_beam_search(self): # assigned but never used bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + @slow def test_max_length_backward_compat_greedy(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -811,6 +905,55 @@ def test_max_length_backward_compat_greedy(self): **model_kwargs, ) + @slow + def test_max_length_backward_compat_contrastive(self): + + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + + bart_tokenizer = BartTokenizer.from_pretrained("bart-base") + bart_model = BartForConditionalGeneration.from_pretrained("bart-base") + input_ids = paddle.to_tensor(bart_tokenizer(article)["input_ids"]).unsqueeze([0]) + + bart_model.eval() + + max_length = 5 + input_ids = paddle.tile(input_ids, [2, 1]) + + bos_token_id = getattr(bart_model, "bos_token_id", None) + eos_token_id = getattr(bart_model, "eos_token_id", None) + pad_token_id = getattr(bart_model, "pad_token_id", None) + decoder_start_token_id = getattr(bart_model, "decoder_start_token_id", None) + + model_kwargs = {} + + model_kwargs["attention_mask"] = bart_model.prepare_attention_mask_for_generation( + input_ids, pad_token_id, eos_token_id + ) + + bart_model.is_encoder_decoder = hasattr(bart_model, "encoder") and hasattr(bart_model, "decoder") + + model_kwargs = bart_model.prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) + + if "decoder_input_ids" in model_kwargs: + input_ids = model_kwargs.pop("decoder_input_ids") + else: + input_ids = bart_model.prepare_decoder_input_ids_for_generation( + input_ids, decoder_start_token_id, bos_token_id + ) + + model_kwargs["use_cache"] = True + max_length += input_ids.shape[-1] + + bart_model.contrastive_search( + input_ids, + max_length=max_length, + pad_token_id=bart_model.bart.config["pad_token_id"], + eos_token_id=bart_model.bart.config["eos_token_id"], + logits_processors=None, + **model_kwargs, + ) + + @slow def test_max_length_backward_compat_sample(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -858,6 +1001,7 @@ def test_max_length_backward_compat_sample(self): **model_kwargs, ) + @slow def test_max_length_backward_compat_beam_search(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -911,6 +1055,7 @@ def 
test_max_length_backward_compat_beam_search(self): **model_kwargs, ) + @slow def test_max_length_backward_compat_group_beam_search(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -963,6 +1108,7 @@ def test_max_length_backward_compat_group_beam_search(self): **model_kwargs, ) + @slow def test_custom_logits_processor(self): article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" @@ -998,7 +1144,7 @@ def test_custom_logits_processor(self): # output_sequences = bart_model.generate(inputs_embeds=inputs_embeds) # self.assertEqual(output_sequences.shape, (1, 5)) - + @slow def test_encoder_decoder_generate_attention_mask(self): articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"] bart_tokenizer = BartTokenizer.from_pretrained("bart-base") @@ -1016,7 +1162,7 @@ def test_encoder_decoder_generate_attention_mask(self): diff = (batched_out - out).abs() - self.assertTrue(diff.numpy() < 1e-6) + self.assertTrue((diff.numpy() < 1e-6).any()) class GenerationUtilsTestCase(unittest.TestCase):