diff --git a/examples/few_shot/prefix-tuning/run_train.py b/examples/few_shot/prefix-tuning/run_train.py
new file mode 100644
index 000000000000..b22ba81af976
--- /dev/null
+++ b/examples/few_shot/prefix-tuning/run_train.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Optional
+
+import paddle
+from utils import PromptTrainerForGeneration, compute_metrics
+
+from paddlenlp.datasets import load_dataset
+from paddlenlp.prompt import (
+    PrefixTemplate,
+    PromptModelForGeneration,
+    PromptTuningArguments,
+)
+from paddlenlp.trainer import PdArgumentParser
+from paddlenlp.transformers import AutoTokenizer, GPTLMHeadModel
+from paddlenlp.utils.log import logger
+
+
+@dataclass
+class DataArguments:
+    prompt: str = field(
+        default="{'prefix':'根据回答和原文得到问题', 'length':50}{'text':'text'}{'sep'}{'text':'labels', 'token_type': 1, 'truncate': False}",
+        metadata={"help": "The prompt template. The 'prefix' and 'text' parts are configurable; the 'text':'labels' part must be kept as-is."},
+    )
+    task_name: str = field(default="dureader_qg", metadata={"help": "The name of the task."})
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: str = field(
+        default="gpt-cpm-small-cn-distill",
+        metadata={"help": "Built-in pretrained model name or the path to a local model."},
+    )
+    export_type: str = field(default="paddle", metadata={"help": "The type to export. Support `paddle` and `onnx`."})
+    dropout: float = field(default=0.1, metadata={"help": "The dropout used for the pretrained model."})
+    predict_with_generate: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to generate during prediction."},
+    )
+    num_beams: Optional[int] = field(
+        default=2,
+        metadata={"help": "The number of beams to use in beam search."},
+    )
+    max_target_length: Optional[int] = field(
+        default=16,
+        metadata={
+            "help": (
+                "The maximum total sequence length for target text after tokenization "
+                "during ``evaluate`` and ``predict``. Sequences longer than this will be "
+                "truncated, sequences shorter will be padded."
+            )
+        },
+    )
+
+
+def main():
+    # Parse the arguments.
+    parser = PdArgumentParser((ModelArguments, DataArguments, PromptTuningArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    training_args.generation_max_length = model_args.max_target_length
+    training_args.predict_with_generate = model_args.predict_with_generate
+    training_args.generation_num_beams = model_args.num_beams
+
+    training_args.print_config(model_args, "Model")
+    training_args.print_config(data_args, "Data")
+    paddle.set_device(training_args.device)
+
+    # Load the pretrained language model.
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+    tokenizer.pad_token = ""
+    tokenizer.sep_token = ""
+    tokenizer.add_tokens("[Space]", special_tokens=True)
+    model = GPTLMHeadModel.from_pretrained(
+        model_args.model_name_or_path,
+        hidden_dropout_prob=model_args.dropout,
+        attention_probs_dropout_prob=model_args.dropout,
+    )
+
+    # Define the template for preprocessing.
+    template = PrefixTemplate(data_args.prompt, tokenizer, training_args.max_seq_length, model)
+    logger.info("Using template: {}".format(template.prompt))
+
+    # Load datasets.
+    train_ds, dev_ds = load_dataset(data_args.task_name, splits=["train", "dev"])
+
+    def convert_label_keyword(input_dict):
+        if "text" not in input_dict:
+            input_dict["text"] = ("答案:" + input_dict.pop("title") + "," + "上下文:" + input_dict.pop("source"))[:400]
+        if "labels" not in input_dict:
+            input_dict["labels"] = "在已知答案的前提下,问题:" + input_dict.pop("target")[:20]
+        return input_dict
+
+    train_ds.map(convert_label_keyword)
+    dev_ds.map(convert_label_keyword)
+
+    # Initialize the prompt model with the above variables.
+    prompt_model = PromptModelForGeneration(
+        model,
+        template,
+        freeze_plm=training_args.freeze_plm,
+        freeze_dropout=training_args.freeze_dropout,
+    )
+
+    dev_compute_metrics = partial(compute_metrics, tokenizer=tokenizer)
+    trainer = PromptTrainerForGeneration(
+        model=prompt_model,
+        tokenizer=tokenizer,
+        args=training_args,
+        train_dataset=train_ds,
+        eval_dataset=dev_ds,
+        callbacks=None,
+        compute_metrics=dev_compute_metrics,
+    )
+
+    # Training.
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        metrics = train_result.metrics
+        trainer.save_model()
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
+
+    if training_args.do_eval:
+        eval_metrics = trainer.evaluate()
+        trainer.log_metrics("eval", eval_metrics)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/few_shot/prefix-tuning/utils.py b/examples/few_shot/prefix-tuning/utils.py
new file mode 100644
index 000000000000..545ec4144bb7
--- /dev/null
+++ b/examples/few_shot/prefix-tuning/utils.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from rouge import Rouge
+
+from paddlenlp.metrics import BLEU
+from paddlenlp.prompt import PromptTrainer
+
+
+# Define the metric function.
+def compute_metrics(eval_preds, tokenizer):
+
+    all_preds = []
+    all_labels = []
+    labels = eval_preds.label_ids
+    preds = eval_preds.predictions
+    all_preds.extend(tokenizer.convert_ids_to_string(pred.tolist()) for pred in preds)
+    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+    all_labels.extend(tokenizer.convert_ids_to_string(label.tolist()) for label in labels)
+
+    assert len(all_preds) == len(all_labels), (
+        "The length of pred_responses should be equal to the length of "
+        "target_responses. But received {} and {}.".format(len(all_preds), len(all_labels))
+    )
+    rouge = Rouge()
+    bleu4 = BLEU(n_size=4)
+    scores = []
+    for pred, target in zip(all_preds, all_labels):
+        try:
+            score = rouge.get_scores(" ".join(pred), " ".join(target))
+            scores.append([score[0]["rouge-1"]["f"], score[0]["rouge-2"]["f"], score[0]["rouge-l"]["f"]])
+        except ValueError:
+            scores.append([0, 0, 0])
+        bleu4.add_inst(pred, [target])
+    rouge1 = np.mean([i[0] for i in scores])
+    rouge2 = np.mean([i[1] for i in scores])
+    rougel = np.mean([i[2] for i in scores])
+    print("\n" + "*" * 15)
+    print("The auto evaluation result is:")
+    print("rouge-1:", round(rouge1, 4))
+    print("rouge-2:", round(rouge2, 4))
+    print("rouge-L:", round(rougel, 4))
+    print("BLEU-4:", round(bleu4.score(), 4))
+    return {"rougel": rougel}
+
+
+class PromptTrainerForGeneration(PromptTrainer):
+    def __init__(
+        self,
+        model,
+        tokenizer,
+        criterion=None,
+        args=None,
+        data_collator=None,
+        train_dataset=None,
+        eval_dataset=None,
+        compute_metrics=None,
+        callbacks=None,
+        optimizers=(None, None),
+    ):
+        super(PromptTrainerForGeneration, self).__init__(
+            model=model,
+            criterion=criterion,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            tokenizer=tokenizer,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+        )
+        self.verbalizer = None
+
+    def prediction_step(
+        self,
+        model,
+        inputs,
+        prediction_loss_only,
+        ignore_keys=None,
+    ):
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Layer`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[paddle.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+
+        Return:
+            Tuple[Optional[float], Optional[paddle.Tensor], Optional[paddle.Tensor]]: A tuple with the loss, logits and
+            labels (each being optional).
+        """
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+
+        has_labels = "labels" in inputs
+        labels = inputs["labels"]
+        # inputs = self._prepare_inputs(inputs)
+
+        max_length = 32
+        generated_tokens = self.model.generate(
+            model_kwargs=inputs,
+        )
+
+        # Different from the HF implementation, `generate` returns a tuple[Tensor] with two elements: ids and scores.
+        if isinstance(generated_tokens, tuple):
+            generated_tokens = generated_tokens[0]
+        # In case the batch is shorter than max length, the output should be padded.
+        if max_length is not None and generated_tokens.shape[-1] < max_length:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, max_length)
+
+        with paddle.no_grad():
+            if has_labels:
+                with self.autocast_smart_context_manager():
+                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                loss = loss.mean().detach()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        return (loss, generated_tokens, labels)
+
+    def _pad_tensors_to_max_len(self, tensor, max_length):
+        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+            # If the PAD token is not defined, at least the EOS token has to be defined.
+            pad_token_id = (
+                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+            )
+        else:
+            if self.tokenizer.pad_token_id is not None:
+                pad_token_id = self.tokenizer.pad_token_id
+            else:
+                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+        # paddle.ones needs to support device args.
+        padded_tensor = pad_token_id * paddle.ones((tensor.shape[0], max_length), dtype=tensor.dtype)
+        padded_tensor[:, : tensor.shape[-1]] = tensor
+        return padded_tensor
diff --git a/paddlenlp/prompt/prompt_model.py b/paddlenlp/prompt/prompt_model.py
index 496662d5c7f6..2707578a0309 100644
--- a/paddlenlp/prompt/prompt_model.py
+++ b/paddlenlp/prompt/prompt_model.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Any, Dict, Optional
 
 import paddle
 from paddle.static import InputSpec
 
 from ..transformers.model_outputs import (
+    CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     SequenceClassifierOutput,
@@ -160,3 +160,138 @@ def get_input_spec(self):
         if "encoder" in template_keywords:
             input_spec.append(InputSpec(shape=[None, None], dtype="int64", name="encoder_ids"))
         return input_spec
+
+
+class PromptModelForGeneration(paddle.nn.Layer):
+    """
+    PromptModel for generation tasks.
+ """ + + def __init__( + self, + model: paddle.nn.Layer, + template: Template, + freeze_plm: bool = False, + freeze_dropout: bool = False, + ): + super(PromptModelForGeneration, self).__init__() + self.plm = model + self.template = template + self.freeze_plm = freeze_plm + self.freeze_dropout = freeze_dropout + if self.freeze_plm: + for param in self.plm.parameters(): + param.stop_gradient = True + if self.freeze_dropout: + self.plm.eval() + self.forward_keys = signature(self.plm.forward) + self._mask_token_id = self.template.tokenizer.mask_token_id + self._pad_token_id = self.template.tokenizer.pad_token_id + if not isinstance(self.template, PrefixTemplate): + raise TypeError(f"{self.__class__.__name__} is not compatible with {self.template.__class__.__name__} ") + self.plm = self.template.process_model(self.plm) + self.forward_keys.append("past_key_values") + self.base_model_prepare_inputs_for_generation = self.plm.prepare_inputs_for_generation + + def forward( + self, + input_ids: paddle.Tensor, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + soft_token_ids: Optional[paddle.Tensor] = None, + encoder_ids: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + return_dict: Optional[bool] = None, + **kwargs: Dict[str, Any] + ): + return_dict = return_dict if return_dict is not None else False + if soft_token_ids is None: + outputs = self.plm(input_ids) + return outputs + + return_hidden_states = kwargs.get("return_hidden_states", False) + input_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "position_ids": position_ids, + "soft_token_ids": soft_token_ids, + "encoder_ids": encoder_ids, + "labels": labels, + **kwargs, + } + input_dict = self.template.process_batch(input_dict) + input_dict = {**input_dict, **kwargs} + model_inputs = {k: input_dict[k] for k in input_dict if k in self.forward_keys} + if "cache" in self.forward_keys: + model_inputs["cache"] = [] + for i in range(len(model_inputs["past_key_values"])): + from paddlenlp.transformers.gpt.modeling import MultiHeadAttention + + model_inputs["cache"].append( + MultiHeadAttention.Cache( + k=model_inputs["past_key_values"][i][0], v=model_inputs["past_key_values"][i][1] + ) + ) + model_inputs.pop("past_key_values") + model_inputs.pop("labels") + model_outputs = self.plm(**model_inputs, return_dict=True, use_cache=True) + logits = model_outputs.logits + + loss = None + if labels is not None: + shift_labels = labels[..., 1:] + shift_logits = logits[..., : shift_labels.shape[1], :] + loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=-100, reduction="mean") + loss = loss_fct(shift_logits.reshape((-1, shift_logits.shape[-1])), shift_labels.reshape((-1,))).reshape( + [-1] + ) + + if not return_dict: + output = (logits,) + if return_hidden_states: + output = output + (model_outputs.logits,) + if loss is not None: + return (loss,) + output + if isinstance(output, (list, tuple)) and len(output) == 1: + output = output[0] + return output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.logits, + ) + + def generate(self, model_kwargs, **kwargs): + self.plm.prepare_inputs_for_generation = self.prepare_inputs_for_generation + generated_tokens = self.plm.generate(**model_kwargs, **kwargs) + return generated_tokens + + def prepare_inputs_for_generation(self, input_ids, use_cache=False, cache=None, **kwargs): + model_kwargs = 
+        model_kwargs = self.base_model_prepare_inputs_for_generation(input_ids, cache=None, **kwargs)
+        model_kwargs["soft_token_ids"] = kwargs.get("soft_token_ids", None)
+        model_kwargs["token_type_ids"] = kwargs.get("token_type_ids", None)
+        model_kwargs["encoder_ids"] = kwargs.get("encoder_ids", None)
+        len_dif = len(model_kwargs["token_type_ids"][0]) - len(model_kwargs["soft_token_ids"][0])
+        for _ in range(len_dif):
+            model_kwargs["soft_token_ids"] = paddle.concat(
+                [model_kwargs["soft_token_ids"], paddle.to_tensor([[0]])], axis=1
+            )
+        input_dict = self.template.process_batch(model_kwargs)
+        model_inputs = {k: input_dict[k] for k in input_dict if k in self.forward_keys}
+        if "cache" in self.forward_keys:
+            model_inputs["cache"] = []
+            for i in range(len(model_inputs["past_key_values"])):
+                from paddlenlp.transformers.gpt.modeling import MultiHeadAttention
+
+                model_inputs["cache"].append(
+                    MultiHeadAttention.Cache(
+                        k=model_inputs["past_key_values"][i][0], v=model_inputs["past_key_values"][i][1]
+                    )
+                )
+            model_inputs.pop("past_key_values")
+        model_inputs["use_cache"] = True
+        model_inputs["return_dict"] = True
+
+        return model_inputs
diff --git a/paddlenlp/prompt/prompt_tokenizer.py b/paddlenlp/prompt/prompt_tokenizer.py
index 8e41162c5ab6..2986b80409c5 100644
--- a/paddlenlp/prompt/prompt_tokenizer.py
+++ b/paddlenlp/prompt/prompt_tokenizer.py
@@ -43,6 +43,15 @@ def __call__(self, inputs: List[Dict[str, Any]]):
             # Create input_ids.
             soft_token_ids = part.get("soft_tokens", None)
             if soft_token_ids is None or len(soft_token_ids) == 1 and soft_token_ids[0] == 0:
+                if "generator_labels" in part:
+                    # Encode the generation labels separately and drop this part from the prompt inputs.
+                    encoded_inputs["labels"].append(
+                        self.tokenizer.encode(
+                            part["generator_labels"], add_special_tokens=False, return_token_type_ids=False
+                        )["input_ids"]
+                    )
+                    inputs.remove(part)
+                    continue
                 orig_input_ids.append(
                     self.tokenizer.encode(part["text"], add_special_tokens=False, return_token_type_ids=False)[
                         "input_ids"
@@ -127,7 +136,7 @@ def _create_max_lengths_from_do_truncate(self, part_text: List[str], part_do_tru
         Create the max sequence length of each part, where the longest part is truncated first.
         """
         text_length = sum([len(x) for x in part_text])
-        num_special_token = self.tokenizer.num_special_tokens_to_add()
+        num_special_token = self.tokenizer.num_special_tokens_to_add(pair=False)
         max_length = self.max_length - num_special_token
         if text_length <= max_length:
             return [None] * len(part_text)
diff --git a/paddlenlp/prompt/template.py b/paddlenlp/prompt/template.py
index 2c6c5ef52a5c..266ea7132cf5 100644
--- a/paddlenlp/prompt/template.py
+++ b/paddlenlp/prompt/template.py
@@ -251,7 +251,8 @@ def encode(self, example: Dict[str, Any]):
         inputs = []
         for value in list(zip(*input_values)):
             inputs.append(dict(zip(input_names, value)))
-
+        if "labels" in example and isinstance(example["labels"], str):
+            inputs.append({"generator_labels": example["labels"], "do_truncate": False})
         input_dict = self.prompt_tokenizer(inputs)
 
         unused_example = {k: v for k, v in example.items() if k not in self.example_keys}
diff --git a/paddlenlp/transformers/gpt/tokenizer.py b/paddlenlp/transformers/gpt/tokenizer.py
index 63ac00d1281f..99f9ade54ef3 100644
--- a/paddlenlp/transformers/gpt/tokenizer.py
+++ b/paddlenlp/transformers/gpt/tokenizer.py
@@ -13,16 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
+import shutil
 from functools import lru_cache
-import json
 
 import jieba
-import shutil
 import sentencepiece as spm
 from paddle.utils import try_import
 
-from .. import PretrainedTokenizer, AddedToken
+from .. import AddedToken, PretrainedTokenizer
 
 __all__ = [
     "GPTTokenizer",
@@ -200,6 +200,7 @@ def convert_tokens_to_ids(self, tokens):
         return [self._convert_token_to_id(token) for token in tokens]
     '''
 
+    '''
     def convert_ids_to_tokens(self, ids):
         """
         Converts a single index or a sequence of indices to a token or a
@@ -227,6 +228,7 @@ def convert_ids_to_tokens(self, ids):
             return self._convert_id_to_token(ids)
         tokens = [self._convert_id_to_token(_id) for _id in ids]
         return tokens
+    '''
 
     @property
     def vocab_size(self):
diff --git a/tests/prompt/test_prompt_model.py b/tests/prompt/test_prompt_model.py
index a7d853956dda..2d8cd6b4d863 100644
--- a/tests/prompt/test_prompt_model.py
+++ b/tests/prompt/test_prompt_model.py
@@ -17,6 +17,7 @@
 from paddlenlp.prompt import (
     AutoTemplate,
     PromptDataCollatorWithPadding,
+    PromptModelForGeneration,
     PromptModelForSequenceClassification,
     SoftVerbalizer,
 )
@@ -24,6 +25,7 @@
     AutoModelForMaskedLM,
     AutoModelForSequenceClassification,
     AutoTokenizer,
+    GPTLMHeadModel,
 )
 
 
@@ -116,5 +118,50 @@ def test_efl_with_labels(self):
         self.assertEqual(model_outputs.hidden_states.shape[0], len(examples))
 
 
+class PromptModelTestForGeneration(unittest.TestCase):
+    def test_generation_with_labels(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-gpt")
+        self.tokenizer.pad_token = ""
+        self.tokenizer.sep_token = ""
+        self.tokenizer.add_tokens("[Space]", special_tokens=True)
+        self.model = GPTLMHeadModel.from_pretrained("__internal_testing__/tiny-random-gpt")
+
+        self.template = AutoTemplate.create_from(
+            prompt="{'prefix':'文本摘要', 'encoder': 'mlp'}{'text':'text'}{'sep'}{'text':'labels', 'token_type': 1}",
+            tokenizer=self.tokenizer,
+            max_length=512,
+            model=self.model,
+        )
+
+        self.data_collator = PromptDataCollatorWithPadding(self.tokenizer, padding=True, return_tensors="pd")
+        self.prompt_model = PromptModelForGeneration(self.model, self.template)
+        examples = [
+            {
+                "text": "日前,方舟子发文直指林志颖旗下爱碧丽推销假保健品,引起哗然。调查发现,爱碧丽没有自己的生产加工厂。其胶原蛋白饮品无核心研发,全部代工生产。号称有“逆生长”功效的爱碧丽“梦幻奇迹限量组”售价高达1080元,实际成本仅为每瓶4元!",
+                "labels": "林志颖公司疑涉虚假营销无厂房无研发",
+                "id": 0,
+            },
+            {
+                "text": "韩方应对路径可以概括为:企业道歉担责;政府公正不护短;民间祈福关怀。他们深知形象的重要,竭力呵护企业品牌和国家形象。正如有评论,韩国“政府+企业+民众”三位一体式呵护韩国国家形象的“苦心经营”,的确有值得我们借鉴之处。",
+                "labels": "从韩亚航空事故看其应对路径",
+                "id": 1,
+            },
+        ]
+        encoded_examples = [self.template(i) for i in examples]
+        loss, logits, hidden_states = self.prompt_model(
+            **self.data_collator(encoded_examples), return_hidden_states=True
+        )
+        self.assertIsNotNone(loss)
+        self.assertEqual(logits.shape[0], len(examples))
+        self.assertEqual(hidden_states.shape[0], len(examples))
+
+        model_outputs = self.prompt_model(
+            **self.data_collator(encoded_examples), return_dict=True, return_hidden_states=True
+        )
+        self.assertIsNotNone(model_outputs.loss)
+        self.assertEqual(model_outputs.logits.shape[0], len(examples))
+        self.assertEqual(model_outputs.hidden_states.shape[0], len(examples))
+
+
 if __name__ == "__main__":
     unittest.main()