@@ -21,6 +21,7 @@
import re
import traceback
from abc import abstractmethod
+from functools import partial
from typing import Any, Dict, List, Optional

import numpy as np
@@ -32,6 +33,10 @@
from paddlenlp.utils.log import logger

from .prompt_tokenizer import MLMPromptTokenizer
+from .prompt_utils import (
+    masked_lm_forward_with_past_key_values,
+    sequence_classification_forward_with_past_key_values,
+)

__all__ = ["Template", "ManualTemplate", "SoftTemplate", "PrefixTemplate", "AutoTemplate", "UTCTemplate"]

@@ -263,8 +268,10 @@ def save(self, save_path):
        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)
        template_config_file = os.path.join(save_path, TEMPLATE_CONFIG_FILE)
+        template_class = self.__class__.__name__
        with open(template_config_file, "w", encoding="utf-8") as fp:
-            fp.write(json.dumps(self._prompt, ensure_ascii=False))
+            fp.write(json.dumps(self._prompt, ensure_ascii=False) + "\n")
+            fp.write(json.dumps({"class": template_class}, ensure_ascii=False) + "\n")
        template_param_file = os.path.join(save_path, TEMPLATE_PARAMETER_FILE)
        template_state_dict = self.state_dict()
        if len(template_state_dict) > 0:
@@ -709,36 +716,54 @@ def parse_soft_prompt(self):
                raise ValueError("Keyword `prefix` should locate at the beginning of template.")
            part["soft"] = part["prefix"]
            part.pop("prefix")
+            if "encoder" not in part:
+                part["encoder"] = "mlp"
            prompt[index] = part

        self._prompt = prompt
        return super(PrefixTemplate, self).parse_soft_prompt()

+    def process_model(self, model):
+        if model.__class__.__name__.endswith("ForSequenceClassification"):
+            model.forward = partial(sequence_classification_forward_with_past_key_values, self=model)
+        elif model.__class__.__name__.endswith("ForMaskedLM"):
+            model.forward = partial(masked_lm_forward_with_past_key_values, self=model)
+        return model
+
    def process_batch(self, input_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        word_embeds = self.word_embeddings(input_dict["input_ids"])
+        batch_size, _ = input_dict["soft_token_ids"].shape
+
+        soft_token_ids = paddle.masked_select(input_dict["soft_token_ids"], input_dict["soft_token_ids"] > 0)
+        soft_token_ids = soft_token_ids.reshape([batch_size, -1])
+        _, soft_len = soft_token_ids.shape
+
+        token_type_ids = paddle.masked_select(input_dict["token_type_ids"], input_dict["soft_token_ids"] == 0)
+        input_dict["token_type_ids"] = token_type_ids.reshape([batch_size, -1])
+        position_ids = paddle.masked_select(input_dict["position_ids"], input_dict["soft_token_ids"] == 0)
+        input_dict["position_ids"] = position_ids.reshape([batch_size, -1])
+ if "masked_position" in input_dict and input_dict ["masked_positions" ] is not None :
+            input_dict["masked_positions"] = input_dict["masked_positions"] - soft_len
+        input_dict["inputs_embeds"] = paddle.concat(
+            [word_embeds[:, 0, :].unsqueeze(1), word_embeds[:, soft_len + 1 :, :]], axis=1
+        )
+
        if "attention_mask" not in input_dict or input_dict["attention_mask"] is None:
            pad_token_id = self.tokenizer.pad_token_id
            attention_mask = paddle.unsqueeze(
                (input_dict["input_ids"] == pad_token_id).astype("float32") * -1e4, axis=[1, 2]
            )
            input_dict["attention_mask"] = attention_mask
        input_dict["input_ids"] = None
-
-        batch_size, _ = input_dict["soft_token_ids"].shape
-        soft_token_ids = paddle.masked_select(input_dict["soft_token_ids"], input_dict["soft_token_ids"] > 0)
-        soft_token_ids = soft_token_ids.reshape([batch_size, -1])
-        _, soft_len = soft_token_ids.shape
-
-        input_dict["inputs_embeds"] = word_embeds[:, soft_len:, :]
+        input_dict.pop("soft_token_ids")
+        input_dict.pop("encoder_ids")

        soft_embeds = self.soft_embeddings(soft_token_ids)
-        for encoder_id in range(1, len(self.encoder_list)):
-            to_encode = paddle.where(input_dict["encoder_ids"] == encoder_id)
-            encoded = self.encoder_list[encoder_id](to_encode)
-            soft_embeds = paddle.where(input_dict["encoder_ids"] == encoder_id, encoded, soft_embeds)
+        soft_embeds = self.encoder_list[1](soft_embeds)
        soft_embeds = soft_embeds.reshape(
            [batch_size, soft_len, self.n_layer * 2, self.n_heads, self.embed_size // self.n_heads]
        )
+
        soft_embeds = self.dropout(soft_embeds)
        soft_embeds = paddle.transpose(soft_embeds, perm=[2, 0, 3, 1, 4])
        soft_embeds = paddle.split(soft_embeds, num_or_sections=self.n_layer)
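For orientation: the reshape, transpose, and split at the end of this hunk pack the encoded prefix embeddings into one stacked (key, value) tensor per transformer layer, the layout that past_key_values-style forwards (such as the masked_lm_forward_with_past_key_values wrapper installed by process_model) consume. A minimal shape sketch with made-up sizes, separate from the diff:

import paddle

# Made-up sizes, for illustration only.
batch_size, soft_len, n_layer, n_heads, embed_size = 2, 8, 12, 12, 768

# After self.encoder_list[1], soft_embeds carries n_layer * 2 * embed_size features per prefix token,
# as implied by the reshape above.
soft_embeds = paddle.randn([batch_size, soft_len, n_layer * 2 * embed_size])

soft_embeds = soft_embeds.reshape([batch_size, soft_len, n_layer * 2, n_heads, embed_size // n_heads])
soft_embeds = paddle.transpose(soft_embeds, perm=[2, 0, 3, 1, 4])
# -> [n_layer * 2, batch_size, n_heads, soft_len, head_dim]
past_key_values = paddle.split(soft_embeds, num_or_sections=n_layer)
# -> n_layer tensors, each [2, batch_size, n_heads, soft_len, head_dim]:
#    a stacked (key, value) pair for one transformer layer.
print(len(past_key_values), past_key_values[0].shape)  # 12 [2, 2, 12, 8, 64]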
@@ -776,6 +801,7 @@ def create_from(
        model: PretrainedModel = None,
        soft_embeddings: Tensor = None,
        prefix_dropout: float = 0.1,
+        template_class: str = None,
    ):
        # Default template if not defined.
        if prompt is None:
@@ -791,12 +817,20 @@ def create_from(
        if "mask" not in template_keywords:
            prompt = prompt + [{"mask": None}]

+        if template_class is None:
+            if "prefix" in template_keywords:
+                template_class = "PrefixTemplate"
+            elif "soft" in template_keywords or "soft_id" in template_keywords:
+                template_class = "SoftTemplate"
+            else:
+                template_class = "ManualTemplate"
+
        # Choose Template according to template keywords.
-        if "prefix" in template_keywords:
+        if template_class == "PrefixTemplate":
            return PrefixTemplate(
                prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model, prefix_dropout=prefix_dropout
            )
-        elif "soft" in template_keywords or "soft_id" in template_keywords:
+        elif template_class == "SoftTemplate":
            word_embeddings = model.get_input_embeddings()
            return SoftTemplate(
                prompt=prompt,
@@ -805,10 +839,12 @@ def create_from(
                word_embeddings=word_embeddings,
                soft_embeddings=soft_embeddings,
            )
-        elif "options" in template_keywords:
+        elif template_class == "UTCTemplate":
            return UTCTemplate(tokenizer=tokenizer, max_length=max_length)
-        else:
+        elif template_class == "ManualTemplate":
            return ManualTemplate(prompt=prompt, tokenizer=tokenizer, max_length=max_length)
+        else:
+            raise ValueError(f"Unknown template: {template_class}.")

    @classmethod
    def load_from(
@@ -818,9 +854,15 @@ def load_from(
        if not os.path.isfile(template_config_file):
            raise ValueError("{} not found under {}".format(TEMPLATE_CONFIG_FILE, data_path))
        with open(template_config_file, "r") as fp:
-            prompt = json.loads(fp.readline().strip())
-        # TODO (Huijuan): Load all configs from data_path.
-        template = cls.create_from(prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model)
+            config = [x.strip() for x in fp]
+        prompt = json.loads(config[0])
+        if len(config) > 1:
+            template_class = json.loads(config[1])["class"]
+        else:
+            template_class = None  # Compatible with previous versions
+        template = cls.create_from(
+            prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model, template_class=template_class
+        )
        template_param_file = os.path.join(data_path, TEMPLATE_PARAMETER_FILE)
        if os.path.isfile(template_param_file):
            template.set_state_dict(paddle.load(template_param_file))
@@ -834,10 +876,14 @@ class UTCTemplate(Template):

    template_special_tokens = ["text", "hard", "sep", "cls", "options"]

-    def __init__(self, tokenizer: PretrainedTokenizer, max_length: int):
+    def __init__(self, tokenizer: PretrainedTokenizer, max_length: int, prompt: str = None):
        prompt = (
-            "{'options': 'choices', 'add_omask': True, 'position': 0, 'token_type': 1}"
-            "{'sep': None, 'token_type': 0, 'position': 0}{'text': 'text_a'}{'sep': None, 'token_type': 1}{'text': 'text_b'}"
+            (
+                "{'options': 'choices', 'add_omask': True, 'position': 0, 'token_type': 1}"
+                "{'sep': None, 'token_type': 0, 'position': 0}{'text': 'text_a'}{'sep': None, 'token_type': 1}{'text': 'text_b'}"
+            )
+            if prompt is None
+            else prompt
        )
        super(UTCTemplate, self).__init__(prompt, tokenizer, max_length)
        self.max_position_id = self.tokenizer.model_max_length - 1
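Taken together, save() now writes a two-line template config (the prompt on line 1, {"class": ...} on line 2) and load_from() forwards the recorded class name to AutoTemplate.create_from(). A rough usage sketch, assuming a working PaddleNLP environment; the checkpoint name, prompt string, save path, and the keyword names passed to load_from() are placeholders or inferred from the code above:

from paddlenlp.prompt import AutoTemplate
from paddlenlp.transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")   # placeholder checkpoint
model = AutoModelForMaskedLM.from_pretrained("ernie-3.0-medium-zh")

# The "prefix" keyword selects PrefixTemplate; the prompt string is illustrative only.
template = AutoTemplate.create_from(
    prompt="{'prefix': None, 'length': 8}{'text': 'text_a'}{'mask': None}",
    tokenizer=tokenizer,
    max_length=128,
    model=model,
)
template.save("./prefix_ckpt")  # writes the prompt plus {"class": "PrefixTemplate"}

# load_from() reads both config lines and passes template_class to create_from(),
# so the saved template class is restored directly instead of being re-inferred.
template = AutoTemplate.load_from("./prefix_ckpt", tokenizer=tokenizer, max_length=128, model=model)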