Skip to content

Commit 383210a

Browse files
authored
[Prompt] PromptModelForSequenceClassification and unit tests (#4021)
* split models off
* ready for PR
* allow criterion = None
* add efl test
* add efl test
* fix test
* add to unittests
* debugging
* wip
* ready for review
* address comments
1 parent 041a4be commit 383210a

File tree

4 files changed

+253
-90
lines changed

4 files changed

+253
-90
lines changed

paddlenlp/prompt/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .template import *
16-
from .verbalizer import *
1715
from .prompt_args import *
18-
from .prompt_trainer import *
16+
from .prompt_model import *
1917
from .prompt_tokenizer import *
18+
from .prompt_trainer import *
2019
from .prompt_utils import *
20+
from .template import *
21+
from .verbalizer import *

paddlenlp/prompt/prompt_model.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from typing import Any, Dict, Optional
17+
18+
import paddle
19+
20+
from ..transformers.model_outputs import MaskedLMOutput, SequenceClassifierOutput
21+
from .prompt_utils import signature
22+
from .template import Template
23+
from .verbalizer import Verbalizer
24+
25+
26+
class PromptModelForSequenceClassification(paddle.nn.Layer):
    """
    Prompt-based model for sequence classification tasks.

    Wraps a pretrained language model (PLM) together with a prompt
    ``Template`` and an optional ``Verbalizer``:

    - If the PLM returns a ``MaskedLMOutput``, the verbalizer maps the
      masked-position token logits to label logits (verbalizer required).
    - If the PLM returns a ``SequenceClassifierOutput``, its logits are
      used directly.

    Args:
        model: The pretrained language model.
        template: Prompt template used to preprocess each input batch.
        verbalizer: Maps masked-token logits to label logits. Required
            when ``model`` has a MaskedLM head.
        freeze_plm: If True, stop gradients for all PLM parameters.
        freeze_dropout: If True, keep the PLM in eval mode so its dropout
            layers are disabled.
    """

    def __init__(
        self,
        model: paddle.nn.Layer,
        template: Template,
        verbalizer: Optional[Verbalizer] = None,
        freeze_plm: bool = False,
        freeze_dropout: bool = False,
    ):
        super(PromptModelForSequenceClassification, self).__init__()
        self.plm = model
        self.template = template
        self.verbalizer = verbalizer
        self.freeze_plm = freeze_plm
        self.freeze_dropout = freeze_dropout
        if self.freeze_plm:
            # Freeze the backbone: no gradients flow into PLM parameters.
            for param in self.plm.parameters():
                param.stop_gradient = True
        if self.freeze_dropout:
            # Eval mode disables dropout inside the PLM.
            self.plm.eval()
        # Names accepted by the PLM's forward, used to filter batch keys.
        self.forward_keys = signature(self.plm.forward)
        self._mask_token_id = self.template.tokenizer.mask_token_id
        self._pad_token_id = self.template.tokenizer.pad_token_id

    def forward(
        self,
        input_ids: paddle.Tensor,
        token_type_ids: Optional[paddle.Tensor] = None,
        position_ids: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        labels: Optional[paddle.Tensor] = None,
        masked_positions: Optional[paddle.Tensor] = None,
        soft_token_ids: Optional[paddle.Tensor] = None,
        encoder_ids: Optional[paddle.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Dict[str, Any]
    ):
        """
        Run the template-processed batch through the PLM and compute label
        logits (and, if ``labels`` is given, the loss).

        Returns:
            If ``return_dict`` is falsy, a tuple ``(logits, plm_logits)``
            prefixed with ``loss`` when labels are provided; otherwise a
            ``SequenceClassifierOutput``.

        Raises:
            ValueError: If the PLM has a MaskedLM head but no verbalizer
                was provided, or the PLM output type is unsupported.
        """
        return_dict = return_dict if return_dict is not None else False
        input_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "masked_positions": masked_positions,
            "soft_token_ids": soft_token_ids,
            "attention_mask": attention_mask,
            "encoder_ids": encoder_ids,
        }
        input_dict = self.template.process_batch(input_dict)
        # Forward only the keys the underlying PLM actually accepts.
        model_inputs = {k: input_dict[k] for k in input_dict if k in self.forward_keys}
        # masked_positions is consumed by the verbalizer, not the PLM.
        model_inputs.pop("masked_positions", None)
        model_outputs = self.plm(**model_inputs, return_dict=True)
        if isinstance(model_outputs, MaskedLMOutput):
            if self.verbalizer is None:
                raise ValueError("Verbalizer is required when model uses the MaskedLM head")
            logits = self.verbalizer.process_outputs(model_outputs.logits, input_dict["masked_positions"])
            num_labels = len(self.verbalizer.label_words)
        elif isinstance(model_outputs, SequenceClassifierOutput):
            logits = model_outputs.logits
            # Prefer num_classes; fall back to num_labels when the attribute
            # is absent or unset (not all PLMs define both).
            num_labels = getattr(self.plm, "num_classes", None)
            if num_labels is None:
                num_labels = self.plm.num_labels
        else:
            raise ValueError(f"Model type not supported yet: {type(model_outputs)}")

        loss = None
        if labels is not None:
            if num_labels == 1:
                # Regression task.
                loss_fct = paddle.nn.MSELoss()
                loss = loss_fct(logits, labels)
            elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32:
                # Single-label classification with integer class indices.
                loss_fct = paddle.nn.CrossEntropyLoss()
                loss = loss_fct(logits.reshape((-1, num_labels)), labels.reshape((-1,)))
            else:
                # Multi-label classification with float targets.
                loss_fct = paddle.nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits, model_outputs.logits)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            # NOTE(review): the raw PLM logits are surfaced via hidden_states
            # so the trainer can reuse them (e.g. for RGL loss) — confirm
            # this reuse of the field is intended.
            hidden_states=model_outputs.logits,
        )

    def prompt_parameters(self):
        """
        Get the trainable parameters of the template and verbalizer.
        """
        params = [p for p in self.template.parameters()]
        if self.verbalizer is not None:
            params += [p for p in self.verbalizer.parameters()]
        return params

paddlenlp/prompt/prompt_trainer.py

Lines changed: 21 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1615
import os
16+
from typing import Any, Callable, Dict, List, Optional, Tuple
1717

1818
import paddle
1919
import paddle.nn as nn
2020
import paddle.nn.functional as F
2121

22+
from ..data import DataCollator
2223
from ..datasets import MapDataset
23-
from ..utils.log import logger
24+
from ..losses import RDropLoss
2425
from ..trainer import Trainer, TrainerCallback
2526
from ..trainer.trainer_utils import EvalPrediction, get_scheduler
26-
from ..data import DataCollator
27-
from ..losses import RDropLoss
2827
from ..transformers import PretrainedTokenizer, export_model
29-
28+
from ..utils.log import logger
29+
from .prompt_args import PromptTuningArguments
30+
from .prompt_utils import PromptDataCollatorWithPadding
3031
from .template import AutoTemplate
3132
from .verbalizer import SoftVerbalizer
32-
from .prompt_utils import signature, PromptDataCollatorWithPadding
33-
from .prompt_args import PromptTuningArguments
34-
35-
__all__ = ["PromptTrainer", "PromptModelForSequenceClassification"]
3633

3734

3835
class PromptTrainer(Trainer):
@@ -43,10 +40,10 @@ class PromptTrainer(Trainer):
4340

4441
def __init__(
4542
self,
46-
model: Union[nn.Layer],
43+
model: nn.Layer,
4744
tokenizer: PretrainedTokenizer,
48-
criterion: Union[nn.Layer],
49-
args: PromptTuningArguments = None,
45+
criterion: Optional[nn.Layer] = None,
46+
args: Optional[PromptTuningArguments] = None,
5047
data_collator: Optional[DataCollator] = None,
5148
train_dataset: Optional[MapDataset] = None,
5249
eval_dataset: Optional[MapDataset] = None,
@@ -64,6 +61,9 @@ def __init__(
6461
if data_collator is None:
6562
data_collator = PromptDataCollatorWithPadding(tokenizer, padding=True, return_tensors="pd")
6663

64+
if criterion is None and (args.use_rgl or args.use_rdrop):
65+
raise Exception("'To use 'use_rgl', 'use_rdrop', 'criterion' must be specified")
66+
6767
super(PromptTrainer, self).__init__(
6868
model=model,
6969
criterion=criterion,
@@ -175,7 +175,6 @@ def create_optimizer(self, lr_scheduler=None):
175175
decay_parameters = [
176176
p.name for n, p in self._get_model().named_parameters() if not any(nd in n for nd in ["bias", "norm"])
177177
]
178-
apply_decay_param_fun = lambda x: x in decay_parameters
179178

180179
if len(plm_parameters) > 0:
181180
ppt_lr = self.args.ppt_learning_rate / self.args.learning_rate
@@ -210,7 +209,7 @@ def create_optimizer(self, lr_scheduler=None):
210209

211210
self.optimizer = optim_cls(
212211
learning_rate=lr,
213-
apply_decay_param_fun=apply_decay_param_fun,
212+
apply_decay_param_fun=lambda x: x in decay_parameters,
214213
parameters=params,
215214
weight_decay=self.args.weight_decay,
216215
grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm),
@@ -228,21 +227,22 @@ def compute_loss(self, model, inputs, return_outputs=False):
228227
labels = inputs["labels"]
229228

230229
input_dict = inputs.copy()
231-
input_dict["return_hidden_states"] = True
232-
outputs, hidden_states = model(**input_dict)
233230

234231
if self.criterion is not None:
235-
loss = self.criterion(outputs, labels)
232+
# pop labels to move loss computation out of the model
233+
input_dict.pop("labels")
234+
logits, hidden_states = model(**input_dict)
235+
loss = self.criterion(logits, labels)
236236

237237
if self.args.use_rdrop:
238-
loss = self._compute_rdrop_loss(model, input_dict, outputs, loss)
238+
loss = self._compute_rdrop_loss(model, input_dict, logits, loss)
239239

240240
if self.args.use_rgl:
241241
loss += self._compute_rgl_loss(hidden_states, labels)
242+
else:
243+
loss, logits, _ = model(**input_dict)
242244

243-
outputs = (loss, outputs)
244-
245-
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
245+
outputs = (loss, logits)
246246

247247
return (loss, outputs) if return_outputs else loss
248248

@@ -294,69 +294,3 @@ def export_model(self, export_path, input_spec, export_type="paddle"):
294294
if self.verbalizer is not None:
295295
self.verbalizer.save(export_path)
296296
export_model(self.model, input_spec, export_path, export_type)
297-
298-
299-
class PromptModelForSequenceClassification(nn.Layer):
    """
    PromptModel for classification tasks.

    Combines a pretrained language model with a prompt template and an
    optional verbalizer; the template preprocesses each batch and the
    verbalizer (when present) turns PLM outputs into label outputs.
    """

    def __init__(self, model, template, verbalizer=None, freeze_plm: bool = False, freeze_dropout: bool = False):
        super(PromptModelForSequenceClassification, self).__init__()
        self.plm = model
        self.template = template
        self.verbalizer = verbalizer
        self.freeze_plm = freeze_plm
        self.freeze_dropout = freeze_dropout
        if self.freeze_plm:
            # Freeze the backbone by stopping gradients on every parameter.
            for param in self.plm.parameters():
                param.stop_gradient = True
        if self.freeze_dropout:
            # Eval mode disables dropout inside the PLM.
            self.plm.eval()
        self.forward_keys = signature(self.plm.forward)
        self._mask_token_id = self.template.tokenizer.mask_token_id
        self._pad_token_id = self.template.tokenizer.pad_token_id

    def forward(
        self,
        input_ids,
        token_type_ids=None,
        position_ids=None,
        attention_mask=None,
        masked_positions=None,
        soft_token_ids=None,
        encoder_ids=None,
        **kwargs
    ):
        # Assemble the raw batch and let the template add prompt features.
        batch = self.template.process_batch(
            {
                "input_ids": input_ids,
                "token_type_ids": token_type_ids,
                "position_ids": position_ids,
                "masked_positions": masked_positions,
                "soft_token_ids": soft_token_ids,
                "attention_mask": attention_mask,
                "encoder_ids": encoder_ids,
            }
        )
        # Keep only arguments the PLM's forward signature accepts;
        # masked_positions is consumed by the verbalizer instead.
        model_inputs = {name: value for name, value in batch.items() if name in self.forward_keys}
        if "masked_positions" in model_inputs:
            del model_inputs["masked_positions"]
        outputs = self.plm(**model_inputs)
        if self.verbalizer is None:
            label_outputs = outputs
        else:
            label_outputs = self.verbalizer.process_outputs(outputs, batch["masked_positions"])

        if kwargs.pop("return_hidden_states", False):
            return label_outputs, outputs
        return label_outputs

    def prompt_parameters(self):
        """
        Get the parameters of template and verbalizer.
        """
        params = list(self.template.parameters())
        if self.verbalizer is not None:
            params.extend(self.verbalizer.parameters())
        return params

0 commit comments

Comments
 (0)