
Commit c742b16

[prompt] Fix bug in Template (#4456)
* [trainer] fix unknown variable bug in prompt trainer
* [prompt] fix bug in template
1 parent b4697f0 commit c742b16
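
For orientation, a minimal, hedged sketch of the prefix-tuning flow this commit touches; the checkpoint name, prompt string, and hyperparameters below are illustrative assumptions, not taken from the commit itself.

# Hedged sketch (not from the commit): build a PrefixTemplate via AutoTemplate and
# wrap it in a prompt model. Checkpoint name and prompt string are assumptions.
from paddlenlp.prompt import AutoTemplate, PromptModelForSequenceClassification
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ernie-3.0-base-zh", num_classes=2)
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-base-zh")

# The "prefix" keyword selects PrefixTemplate; with this commit its encoder defaults to "mlp".
template = AutoTemplate.create_from(
    prompt="{'prefix': '文本分类', 'encoder': 'mlp'}{'text': 'text_a'}{'mask': None}",
    tokenizer=tokenizer,
    max_length=128,
    model=model,
)

# With a PrefixTemplate, the prompt model now calls template.process_model(plm) so that
# the generated past_key_values are routed into the backbone's forward.
prompt_model = PromptModelForSequenceClassification(model, template)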

File tree: 4 files changed, +175 −28 lines

paddlenlp/prompt/prompt_model.py
paddlenlp/prompt/prompt_trainer.py
paddlenlp/prompt/prompt_utils.py
paddlenlp/prompt/template.py


paddlenlp/prompt/prompt_model.py

Lines changed: 5 additions & 1 deletion
@@ -24,7 +24,7 @@
     SequenceClassifierOutput,
 )
 from .prompt_utils import signature
-from .template import Template
+from .template import PrefixTemplate, Template
 from .verbalizer import Verbalizer
 
 
@@ -55,6 +55,9 @@ def __init__(
         self.forward_keys = signature(self.plm.forward)
         self._mask_token_id = self.template.tokenizer.mask_token_id
         self._pad_token_id = self.template.tokenizer.pad_token_id
+        if isinstance(self.template, PrefixTemplate):
+            self.plm = self.template.process_model(self.plm)
+            self.forward_keys.append("past_key_values")
 
     def forward(
         self,
@@ -82,6 +85,7 @@ def forward(
             **kwargs,
         }
         input_dict = self.template.process_batch(input_dict)
+        input_dict = {**input_dict, **kwargs}
         model_inputs = {k: input_dict[k] for k in input_dict if k in self.forward_keys}
         if "masked_positions" in model_inputs:
             model_inputs.pop("masked_positions")
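
The added `input_dict = {**input_dict, **kwargs}` line relies on ordinary dict-merge precedence: keys from the right-hand mapping win, so caller-supplied keyword arguments are restored after the template's batch processing. A standalone illustration with made-up values:

# Dict-merge precedence: the right-hand mapping overrides earlier keys, so
# caller-supplied kwargs win over values produced by template.process_batch.
processed = {"input_ids": None, "attention_mask": "from_template"}
kwargs = {"attention_mask": "from_caller", "past_key_values": "cached_prefix"}

merged = {**processed, **kwargs}
assert merged["attention_mask"] == "from_caller"
assert merged["past_key_values"] == "cached_prefix"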

paddlenlp/prompt/prompt_trainer.py

Lines changed: 7 additions & 4 deletions
@@ -198,15 +198,18 @@ def create_optimizer(self, lr_scheduler=None):
                 else:
                     params = plm_parameters
             else:
-                args = self.init_num_steps(self.args, len(self.train_dataset))
+                if self.args.max_steps > 0:
+                    max_steps = self.args.max_steps
+                else:
+                    raise ValueError("Please use `max_steps` to set the maximum training steps.")
                 warmup = (
-                    args.warmup_steps if args.warmup_steps > 0 else int(args.warmup_ratio * args.num_training_steps)
+                    self.args.warmup_steps if self.args.warmup_steps > 0 else int(self.args.warmup_ratio * max_steps)
                 )
                 self.lr_scheduler = get_scheduler(
-                    args.lr_scheduler_type,
+                    self.args.lr_scheduler_type,
                     learning_rate=self.args.ppt_learning_rate,
                     num_warmup_steps=warmup,
-                    num_training_steps=args.num_training_steps,
+                    num_training_steps=max_steps,
                 )
                 lr = self.lr_scheduler
                 params = ppt_parameters
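
Because the rewritten branch raises a ValueError when `max_steps` is unset, runs that only train the prompt parameters now need an explicit step budget. A hedged sketch of the corresponding arguments (the output path and values are illustrative):

# Hedged sketch: with the backbone frozen, the prompt optimizer's scheduler length now
# comes from max_steps, so it must be set explicitly (values below are illustrative).
from paddlenlp.prompt import PromptTuningArguments

training_args = PromptTuningArguments(
    output_dir="./checkpoints",
    max_steps=1000,              # explicit step budget required by the rewritten branch
    ppt_learning_rate=3e-2,      # learning rate for prompt parameters
    warmup_ratio=0.1,
)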

paddlenlp/prompt/prompt_utils.py

Lines changed: 95 additions & 1 deletion
@@ -18,11 +18,13 @@
 
 import inspect
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import paddle
+from paddle import Tensor
 
+from ..transformers.model_outputs import MaskedLMOutput, SequenceClassifierOutput
 from ..transformers.tokenizer_utils_base import PaddingStrategy, PretrainedTokenizerBase
 
 
@@ -114,3 +116,95 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
                 continue
             batch[key] = self._convert_to_tensors(values)
         return batch
+
+
+def sequence_classification_forward_with_past_key_values(
+    self,
+    input_ids: Optional[Tensor] = None,
+    token_type_ids: Optional[Tensor] = None,
+    position_ids: Optional[Tensor] = None,
+    attention_mask: Optional[Tensor] = None,
+    inputs_embeds: Optional[Tensor] = None,
+    labels: Optional[Tensor] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
+):
+    outputs = self.ernie(
+        input_ids,
+        token_type_ids=token_type_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        inputs_embeds=inputs_embeds,
+        past_key_values=past_key_values,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=True,
+    )
+    pooled_output = outputs[1]
+
+    pooled_output = self.dropout(pooled_output)
+    logits = self.classifier(pooled_output)
+
+    loss = None
+    if labels is not None:
+        if self.num_labels == 1:
+            loss_fct = paddle.nn.MSELoss()
+            loss = loss_fct(logits, labels)
+        elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32:
+            loss_fct = paddle.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.reshape((-1, self.num_labels)), labels.reshape((-1,)))
+        else:
+            loss_fct = paddle.nn.BCEWithLogitsLoss()
+            loss = loss_fct(logits, labels)
+
+    return SequenceClassifierOutput(
+        loss=loss,
+        logits=logits,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+def masked_lm_forward_with_past_key_values(
+    self,
+    input_ids: Optional[Tensor] = None,
+    token_type_ids: Optional[Tensor] = None,
+    position_ids: Optional[Tensor] = None,
+    attention_mask: Optional[Tensor] = None,
+    masked_positions: Optional[Tensor] = None,
+    inputs_embeds: Optional[Tensor] = None,
+    labels: Optional[Tensor] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
+):
+    outputs = self.ernie(
+        input_ids,
+        token_type_ids=token_type_ids,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        inputs_embeds=inputs_embeds,
+        past_key_values=past_key_values,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=True,
+    )
+    sequence_output = outputs[0]
+    prediction_scores = self.cls(sequence_output, masked_positions=masked_positions)
+
+    masked_lm_loss = None
+    if labels is not None:
+        loss_fct = paddle.nn.CrossEntropyLoss()
+        masked_lm_loss = loss_fct(
+            prediction_scores.reshape((-1, paddle.shape(prediction_scores)[-1])), labels.reshape((-1,))
+        )
+
+    return MaskedLMOutput(
+        loss=masked_lm_loss,
+        logits=prediction_scores,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
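
Both helpers take `self` as an explicit first parameter so that `PrefixTemplate.process_model` (see template.py below) can graft them onto an existing model instance with `functools.partial`. A framework-free illustration of that binding pattern, with made-up names:

from functools import partial

# Illustration of the binding trick: a module-level function that accepts `self`
# explicitly can stand in for an instance's bound method.
class Greeter:
    def __init__(self, name):
        self.name = name

    def greet(self):
        return f"hello, {self.name}"

def greet_with_suffix(self, suffix="!"):
    # plays the role of *_forward_with_past_key_values: same object, extra keyword
    return f"hello, {self.name}{suffix}"

g = Greeter("prefix")
g.greet = partial(greet_with_suffix, self=g)  # mirrors model.forward = partial(..., self=model)
print(g.greet())             # hello, prefix!
print(g.greet(suffix="?"))   # hello, prefix?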

paddlenlp/prompt/template.py

Lines changed: 68 additions & 22 deletions
@@ -21,6 +21,7 @@
 import re
 import traceback
 from abc import abstractmethod
+from functools import partial
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -32,6 +33,10 @@
 from paddlenlp.utils.log import logger
 
 from .prompt_tokenizer import MLMPromptTokenizer
+from .prompt_utils import (
+    masked_lm_forward_with_past_key_values,
+    sequence_classification_forward_with_past_key_values,
+)
 
 __all__ = ["Template", "ManualTemplate", "SoftTemplate", "PrefixTemplate", "AutoTemplate", "UTCTemplate"]
 
@@ -263,8 +268,10 @@ def save(self, save_path):
         if not os.path.exists(save_path):
             os.makedirs(save_path, exist_ok=True)
         template_config_file = os.path.join(save_path, TEMPLATE_CONFIG_FILE)
+        template_class = self.__class__.__name__
         with open(template_config_file, "w", encoding="utf-8") as fp:
-            fp.write(json.dumps(self._prompt, ensure_ascii=False))
+            fp.write(json.dumps(self._prompt, ensure_ascii=False) + "\n")
+            fp.write(json.dumps({"class": template_class}, ensure_ascii=False) + "\n")
         template_param_file = os.path.join(save_path, TEMPLATE_PARAMETER_FILE)
         template_state_dict = self.state_dict()
         if len(template_state_dict) > 0:
@@ -709,36 +716,54 @@ def parse_soft_prompt(self):
                 raise ValueError("Keyword `prefix` should locate at the beginning of template.")
             part["soft"] = part["prefix"]
             part.pop("prefix")
+            if "encoder" not in part:
+                part["encoder"] = "mlp"
             prompt[index] = part
 
         self._prompt = prompt
         return super(PrefixTemplate, self).parse_soft_prompt()
 
+    def process_model(self, model):
+        if model.__class__.__name__.endswith("ForSequenceClassification"):
+            model.forward = partial(sequence_classification_forward_with_past_key_values, self=model)
+        elif model.__class__.__name__.endswith("ForMaskedLM"):
+            model.forward = partial(masked_lm_forward_with_past_key_values, self=model)
+        return model
+
     def process_batch(self, input_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
         word_embeds = self.word_embeddings(input_dict["input_ids"])
+        batch_size, _ = input_dict["soft_token_ids"].shape
+
+        soft_token_ids = paddle.masked_select(input_dict["soft_token_ids"], input_dict["soft_token_ids"] > 0)
+        soft_token_ids = soft_token_ids.reshape([batch_size, -1])
+        _, soft_len = soft_token_ids.shape
+
+        token_type_ids = paddle.masked_select(input_dict["token_type_ids"], input_dict["soft_token_ids"] == 0)
+        input_dict["token_type_ids"] = token_type_ids.reshape([batch_size, -1])
+        position_ids = paddle.masked_select(input_dict["position_ids"], input_dict["soft_token_ids"] == 0)
+        input_dict["position_ids"] = position_ids.reshape([batch_size, -1])
+        if "masked_position" in input_dict and input_dict["masked_positions"] is not None:
+            input_dict["masked_positions"] = input_dict["masked_positions"] - soft_len
+        input_dict["inputs_embeds"] = paddle.concat(
+            [word_embeds[:, 0, :].unsqueeze(1), word_embeds[:, soft_len + 1 :, :]], axis=1
+        )
+
         if "attention_mask" not in input_dict or input_dict["attention_mask"] is None:
             pad_token_id = self.tokenizer.pad_token_id
             attention_mask = paddle.unsqueeze(
                 (input_dict["input_ids"] == pad_token_id).astype("float32") * -1e4, axis=[1, 2]
             )
             input_dict["attention_mask"] = attention_mask
         input_dict["input_ids"] = None
-
-        batch_size, _ = input_dict["soft_token_ids"].shape
-        soft_token_ids = paddle.masked_select(input_dict["soft_token_ids"], input_dict["soft_token_ids"] > 0)
-        soft_token_ids = soft_token_ids.reshape([batch_size, -1])
-        _, soft_len = soft_token_ids.shape
-
-        input_dict["inputs_embeds"] = word_embeds[:, soft_len:, :]
+        input_dict.pop("soft_token_ids")
+        input_dict.pop("encoder_ids")
 
         soft_embeds = self.soft_embeddings(soft_token_ids)
-        for encoder_id in range(1, len(self.encoder_list)):
-            to_encode = paddle.where(input_dict["encoder_ids"] == encoder_id)
-            encoded = self.encoder_list[encoder_id](to_encode)
-            soft_embeds = paddle.where(input_dict["encoder_ids"] == encoder_id, encoded, soft_embeds)
+        soft_embeds = self.encoder_list[1](soft_embeds)
         soft_embeds = soft_embeds.reshape(
             [batch_size, soft_len, self.n_layer * 2, self.n_heads, self.embed_size // self.n_heads]
         )
+
         soft_embeds = self.dropout(soft_embeds)
         soft_embeds = paddle.transpose(soft_embeds, perm=[2, 0, 3, 1, 4])
         soft_embeds = paddle.split(soft_embeds, num_or_sections=self.n_layer)
@@ -776,6 +801,7 @@ def create_from(
         model: PretrainedModel = None,
         soft_embeddings: Tensor = None,
         prefix_dropout: float = 0.1,
+        template_class: str = None,
     ):
         # Default template if not defined.
         if prompt is None:
@@ -791,12 +817,20 @@
         if "mask" not in template_keywords:
             prompt = prompt + [{"mask": None}]
 
+        if template_class is None:
+            if "prefix" in template_keywords:
+                template_class = "PrefixTemplate"
+            elif "soft" in template_keywords or "soft_id" in template_keywords:
+                template_class = "SoftTemplate"
+            else:
+                template_class = "ManualTemplate"
+
         # Choose Template according to template keywords.
-        if "prefix" in template_keywords:
+        if template_class == "PrefixTemplate":
             return PrefixTemplate(
                 prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model, prefix_dropout=prefix_dropout
             )
-        elif "soft" in template_keywords or "soft_id" in template_keywords:
+        elif template_class == "SoftTemplate":
             word_embeddings = model.get_input_embeddings()
             return SoftTemplate(
                 prompt=prompt,
@@ -805,10 +839,12 @@ def create_from(
                 word_embeddings=word_embeddings,
                 soft_embeddings=soft_embeddings,
             )
-        elif "options" in template_keywords:
+        elif template_class == "UTCTemplate":
             return UTCTemplate(tokenizer=tokenizer, max_length=max_length)
-        else:
+        elif template_class == "ManualTemplate":
             return ManualTemplate(prompt=prompt, tokenizer=tokenizer, max_length=max_length)
+        else:
+            raise ValueError(f"Unknown template: {template_class}.")
 
     @classmethod
     def load_from(
@@ -818,9 +854,15 @@ def load_from(
         if not os.path.isfile(template_config_file):
             raise ValueError("{} not found under {}".format(TEMPLATE_CONFIG_FILE, data_path))
         with open(template_config_file, "r") as fp:
-            prompt = json.loads(fp.readline().strip())
-        # TODO (Huijuan): Load all configs from data_path.
-        template = cls.create_from(prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model)
+            config = [x.strip() for x in fp]
+        prompt = json.loads(config[0])
+        if len(config) > 1:
+            template_class = json.loads(config[1])
+        else:
+            template_class = None  # Compatible with previous versions
+        template = cls.create_from(
+            prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model, template_class=template_class
+        )
         template_param_file = os.path.join(data_path, TEMPLATE_PARAMETER_FILE)
         if os.path.isfile(template_param_file):
             template.set_state_dict(paddle.load(template_param_file))
@@ -834,10 +876,14 @@ class UTCTemplate(Template):
 
     template_special_tokens = ["text", "hard", "sep", "cls", "options"]
 
-    def __init__(self, tokenizer: PretrainedTokenizer, max_length: int):
+    def __init__(self, tokenizer: PretrainedTokenizer, max_length: int, prompt: str = None):
         prompt = (
-            "{'options': 'choices', 'add_omask': True, 'position': 0, 'token_type': 1}"
-            "{'sep': None, 'token_type': 0, 'position': 0}{'text': 'text_a'}{'sep': None, 'token_type': 1}{'text': 'text_b'}"
+            (
+                "{'options': 'choices', 'add_omask': True, 'position': 0, 'token_type': 1}"
+                "{'sep': None, 'token_type': 0, 'position': 0}{'text': 'text_a'}{'sep': None, 'token_type': 1}{'text': 'text_b'}"
+            )
+            if prompt is None
+            else prompt
         )
         super(UTCTemplate, self).__init__(prompt, tokenizer, max_length)
         self.max_position_id = self.tokenizer.model_max_length - 1
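
After the `save`/`load_from` changes above, the template config file holds two JSON lines: the prompt on the first and a {"class": ...} record on the second, while one-line files from older versions still load with `template_class = None`. A standalone sketch of that layout; the file name is illustrative, not the real TEMPLATE_CONFIG_FILE constant.

import json

# Standalone sketch of the two-line config layout; "template_config.json" is illustrative.
prompt = [{"prefix": "文本分类", "encoder": "mlp"}, {"text": "text_a"}, {"mask": None}]

with open("template_config.json", "w", encoding="utf-8") as fp:
    fp.write(json.dumps(prompt, ensure_ascii=False) + "\n")
    fp.write(json.dumps({"class": "PrefixTemplate"}, ensure_ascii=False) + "\n")

with open("template_config.json", "r", encoding="utf-8") as fp:
    config = [x.strip() for x in fp]

loaded_prompt = json.loads(config[0])
template_class = json.loads(config[1]) if len(config) > 1 else None  # older files: single line
print(loaded_prompt, template_class)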
