@@ -27,7 +27,9 @@
 from paddlenlp.transformers import PretrainedTokenizer, PretrainedModel
 from paddlenlp.utils.log import logger
 
-__all__ = ["Verbalizer", "ManualVerbalizer", "SoftVerbalizer"]
+__all__ = [
+    "Verbalizer", "ManualVerbalizer", "SoftVerbalizer", "MaskedLMVerbalizer"
+]
 
 # Verbalizer used to be saved in a file.
 VERBALIZER_CONFIG_FILE = "verbalizer_config.json"
@@ -263,9 +265,11 @@ class ManualVerbalizer(Verbalizer):
             An instance of PretrainedTokenizer for label word tokenization.
     """
 
-    def __init__(self, label_words: Dict, tokenizer: PretrainedTokenizer):
+    def __init__(self, label_words: Dict, tokenizer: PretrainedTokenizer,
+                 **kwargs):
         super(ManualVerbalizer, self).__init__(label_words=label_words,
-                                               tokenizer=tokenizer)
+                                               tokenizer=tokenizer,
+                                               **kwargs)
 
     def create_parameters(self):
         return None
@@ -292,10 +296,7 @@ def aggregate_multiple_mask(self, outputs: Tensor, atype: str = None):
                 "tokens.".format(atype))
         return outputs
 
-    def process_outputs(self,
-                        outputs: Tensor,
-                        masked_positions: Tensor = None,
-                        **kwargs):
+    def process_outputs(self, outputs: Tensor, masked_positions: Tensor = None):
         """
         Process outputs over the vocabulary, including the following steps:
 
@@ -364,10 +365,11 @@ class SoftVerbalizer(Verbalizer):
     LAST_LINEAR = ["AlbertForMaskedLM", "RobertaForMaskedLM"]
 
     def __init__(self, label_words: Dict, tokenizer: PretrainedTokenizer,
-                 model: PretrainedModel):
+                 model: PretrainedModel, **kwargs):
         super(SoftVerbalizer, self).__init__(label_words=label_words,
                                              tokenizer=tokenizer,
-                                             model=model)
+                                             model=model,
+                                             **kwargs)
         del self.model
         setattr(model, self.head_name[0], MaskedLMIdentity())
 
@@ -472,3 +474,63 @@ def _create_init_weight(self, weight, is_bias=False):
                                 axis=1).reshape(word_shape)
         weight = self.aggregate(weight, token_mask, aggr_type)
         return weight
+
+
+class MaskedLMVerbalizer(Verbalizer):
+    """
+    MaskedLMVerbalizer defines mapping from labels to words manually and supports
+    multiple masks corresponding to multiple tokens in words.
+
+    Args:
+        label_words (`dict`):
+            Define the mapping from labels to a single word. Only the first word
+            is used if multiple words are defined.
+        tokenizer (`PretrainedTokenizer`):
+            An instance of PretrainedTokenizer for label word tokenization.
+    """
+
+    def __init__(self, label_words: Dict, tokenizer: PretrainedTokenizer,
+                 **kwargs):
+        super(MaskedLMVerbalizer, self).__init__(label_words=label_words,
+                                                 tokenizer=tokenizer,
+                                                 **kwargs)
+
+    def create_parameters(self):
+        return None
+
+    def aggregate_multiple_mask(self, outputs: Tensor, atype: str = "product"):
+        assert outputs.ndim == 3
+        token_ids = self.token_ids[:, 0, :].T
+        batch_size, num_token, num_pred = outputs.shape
+        results = paddle.index_select(outputs[:, 0, :], token_ids[0], axis=1)
+        if atype == "first":
+            return results
+
+        for index in range(1, num_token):
+            sub_results = paddle.index_select(outputs[:, index, :],
+                                              token_ids[index],
+                                              axis=1)
+            if atype in ("mean", "sum"):
+                results += sub_results
+            elif atype == "product":
+                results *= sub_results
+            elif atype == "max":
+                results = paddle.stack([results, sub_results], axis=-1)
+                results = results.max(axis=-1)
+            else:
+                raise ValueError(
+                    "Strategy {} is not supported to aggregate multiple "
+                    "tokens.".format(atype))
+        if atype == "mean":
+            results = results / num_token
+        return results
+
+    def process_outputs(self, outputs: Tensor, masked_positions: Tensor = None):
+        if masked_positions is None:
+            return outputs
+
+        batch_size, _, num_pred = outputs.shape
+        outputs = outputs.reshape([-1, num_pred])
+        outputs = paddle.gather(outputs, masked_positions)
+        outputs = outputs.reshape([batch_size, -1, num_pred])
+        return outputs
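For reference, below is a minimal standalone sketch of what the two new methods compute, with dummy tensors in place of real MLM logits. The shapes, token ids, and mask positions are illustrative assumptions, not values from this commit.

# Standalone sketch of what MaskedLMVerbalizer.process_outputs and
# aggregate_multiple_mask (atype="product") compute, using dummy tensors.
# All shapes, ids, and positions below are illustrative assumptions.
import paddle

batch_size, seq_len, vocab_size, num_token = 2, 4, 10, 2

# MLM logits for every position in the sequence.
logits = paddle.rand([batch_size, seq_len, vocab_size])

# process_outputs: gather the rows at the [MASK] positions, indexing the
# flattened [batch * seq_len] axis (two masks per example here).
masked_positions = paddle.to_tensor([1, 2, 5, 6])
outputs = paddle.gather(logits.reshape([-1, vocab_size]), masked_positions)
outputs = outputs.reshape([batch_size, -1, vocab_size])  # [2, 2, 10]

# One token id per (mask position, label); stands in for
# self.token_ids[:, 0, :].T in the verbalizer (three labels here).
token_ids = paddle.to_tensor([[1, 4, 7], [2, 5, 8]])

# aggregate_multiple_mask: a label's score is the product of its token
# scores across the mask positions.
results = paddle.index_select(outputs[:, 0, :], token_ids[0], axis=1)
for index in range(1, num_token):
    results *= paddle.index_select(outputs[:, index, :], token_ids[index],
                                   axis=1)

print(results.shape)  # [2, 3]: one aggregated score per label per example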