 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import os
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F

+from ..data import DataCollator
 from ..datasets import MapDataset
-from ..utils.log import logger
+from ..losses import RDropLoss
 from ..trainer import Trainer, TrainerCallback
 from ..trainer.trainer_utils import EvalPrediction, get_scheduler
-from ..data import DataCollator
-from ..losses import RDropLoss
 from ..transformers import PretrainedTokenizer, export_model
-
+from ..utils.log import logger
+from .prompt_args import PromptTuningArguments
+from .prompt_utils import PromptDataCollatorWithPadding
 from .template import AutoTemplate
 from .verbalizer import SoftVerbalizer
-from .prompt_utils import signature, PromptDataCollatorWithPadding
-from .prompt_args import PromptTuningArguments
-
-__all__ = ["PromptTrainer", "PromptModelForSequenceClassification"]


 class PromptTrainer(Trainer):
@@ -43,10 +40,10 @@ class PromptTrainer(Trainer):

     def __init__(
         self,
-        model: Union[nn.Layer],
+        model: nn.Layer,
         tokenizer: PretrainedTokenizer,
-        criterion: Union[nn.Layer],
-        args: PromptTuningArguments = None,
+        criterion: Optional[nn.Layer] = None,
+        args: Optional[PromptTuningArguments] = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[MapDataset] = None,
         eval_dataset: Optional[MapDataset] = None,
@@ -64,6 +61,9 @@ def __init__(
         if data_collator is None:
            data_collator = PromptDataCollatorWithPadding(tokenizer, padding=True, return_tensors="pd")

+        if criterion is None and (args.use_rgl or args.use_rdrop):
+            raise Exception("To use 'use_rgl' or 'use_rdrop', 'criterion' must be specified")
+
         super(PromptTrainer, self).__init__(
             model=model,
             criterion=criterion,
@@ -175,7 +175,6 @@ def create_optimizer(self, lr_scheduler=None):
         decay_parameters = [
             p.name for n, p in self._get_model().named_parameters() if not any(nd in n for nd in ["bias", "norm"])
         ]
-        apply_decay_param_fun = lambda x: x in decay_parameters

         if len(plm_parameters) > 0:
             ppt_lr = self.args.ppt_learning_rate / self.args.learning_rate
@@ -210,7 +209,7 @@ def create_optimizer(self, lr_scheduler=None):

            self.optimizer = optim_cls(
                learning_rate=lr,
-                apply_decay_param_fun=apply_decay_param_fun,
+                apply_decay_param_fun=lambda x: x in decay_parameters,
                parameters=params,
                weight_decay=self.args.weight_decay,
                grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm),
@@ -228,21 +227,22 @@ def compute_loss(self, model, inputs, return_outputs=False):
         labels = inputs["labels"]

         input_dict = inputs.copy()
-        input_dict["return_hidden_states"] = True
-        outputs, hidden_states = model(**input_dict)

         if self.criterion is not None:
-            loss = self.criterion(outputs, labels)
+            # pop labels to move loss computation out of the model
+            input_dict.pop("labels")
+            logits, hidden_states = model(**input_dict)
+            loss = self.criterion(logits, labels)

             if self.args.use_rdrop:
-                loss = self._compute_rdrop_loss(model, input_dict, outputs, loss)
+                loss = self._compute_rdrop_loss(model, input_dict, logits, loss)

             if self.args.use_rgl:
                 loss += self._compute_rgl_loss(hidden_states, labels)
+        else:
+            loss, logits, _ = model(**input_dict)

-        outputs = (loss, outputs)
-
-        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        outputs = (loss, logits)

         return (loss, outputs) if return_outputs else loss
@@ -294,69 +294,3 @@ def export_model(self, export_path, input_spec, export_type="paddle"):
         if self.verbalizer is not None:
             self.verbalizer.save(export_path)
         export_model(self.model, input_spec, export_path, export_type)
-
-
-class PromptModelForSequenceClassification(nn.Layer):
-    """
-    PromptModel for classification tasks.
-    """
-
-    def __init__(self, model, template, verbalizer=None, freeze_plm: bool = False, freeze_dropout: bool = False):
-        super(PromptModelForSequenceClassification, self).__init__()
-        self.plm = model
-        self.template = template
-        self.verbalizer = verbalizer
-        self.freeze_plm = freeze_plm
-        self.freeze_dropout = freeze_dropout
-        if self.freeze_plm:
-            for param in self.plm.parameters():
-                param.stop_gradient = True
-        if self.freeze_dropout:
-            self.plm.eval()
-        self.forward_keys = signature(self.plm.forward)
-        self._mask_token_id = self.template.tokenizer.mask_token_id
-        self._pad_token_id = self.template.tokenizer.pad_token_id
-
-    def forward(
-        self,
-        input_ids,
-        token_type_ids=None,
-        position_ids=None,
-        attention_mask=None,
-        masked_positions=None,
-        soft_token_ids=None,
-        encoder_ids=None,
-        **kwargs
-    ):
-        input_dict = {
-            "input_ids": input_ids,
-            "token_type_ids": token_type_ids,
-            "position_ids": position_ids,
-            "masked_positions": masked_positions,
-            "soft_token_ids": soft_token_ids,
-            "attention_mask": attention_mask,
-            "encoder_ids": encoder_ids,
-        }
-        input_dict = self.template.process_batch(input_dict)
-        model_inputs = {k: input_dict[k] for k in input_dict if k in self.forward_keys}
-        if "masked_positions" in model_inputs:
-            model_inputs.pop("masked_positions")
-        outputs = self.plm(**model_inputs)
-        if self.verbalizer is not None:
-            label_outputs = self.verbalizer.process_outputs(outputs, input_dict["masked_positions"])
-        else:
-            label_outputs = outputs
-
-        if kwargs.pop("return_hidden_states", False):
-            return label_outputs, outputs
-        else:
-            return label_outputs
-
-    def prompt_parameters(self):
-        """
-        Get the parameters of template and verbalizer.
-        """
-        params = [p for p in self.template.parameters()]
-        if self.verbalizer is not None:
-            params += [p for p in self.verbalizer.parameters()]
-        return params
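To illustrate the constructor contract this diff introduces: `criterion` is now optional, but `compute_loss` only takes the external-criterion path when one is passed, and the new check raises if `use_rgl` or `use_rdrop` is enabled without it. The following is a minimal usage sketch, not part of the commit; the `paddlenlp.prompt` import path is assumed, and `prompt_model`, `tokenizer`, `train_ds`, and `dev_ds` are placeholder objects built elsewhere.

    # Hypothetical usage sketch; prompt_model, tokenizer, train_ds and dev_ds are
    # placeholders assumed to exist, not defined by this diff.
    import paddle.nn as nn

    from paddlenlp.prompt import PromptTrainer, PromptTuningArguments

    args = PromptTuningArguments(
        output_dir="./checkpoints",
        use_rdrop=True,  # R-Drop regularization now requires an explicit criterion
    )

    trainer = PromptTrainer(
        model=prompt_model,               # a prompt model wrapping a pretrained LM
        tokenizer=tokenizer,
        criterion=nn.CrossEntropyLoss(),  # mandatory when use_rdrop or use_rgl is set
        args=args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
    )
    trainer.train()

Without `use_rdrop`/`use_rgl`, `criterion` can be omitted and the model itself is expected to return `(loss, logits, ...)`, matching the new `else` branch in `compute_loss`.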