@@ -31,6 +31,7 @@ class PromptConfig(SwiftConfig):
         attention_mask_value: The value to pad to the attention mask
         prompt_length: The length of the prompt tokens
         attach_front: When set to True, prompt is attached in front of the embedding
+        extract_embedding: Whether the embedding is extracted at the final stage to keep the same dims as the inputs
     """
 
     dim: int = field(
@@ -60,6 +61,13 @@ class PromptConfig(SwiftConfig):
             'help':
             'When set to True, prompt is attached in front of the embedding'
         })
+
+    extract_embedding: bool = field(
+        default=False,
+        metadata={
+            'help':
+            'Whether the embedding is extracted at the final stage to keep the same dims as the inputs'
+        })
 
     def __post_init__(self):
         from .mapping import SwiftTuners
@@ -71,6 +79,7 @@ class Prompt:
     @staticmethod
     def prepare_model(model: nn.Module, config: PromptConfig):
         module_keys = [key for key, _ in model.named_modules()]
+        match_module_keys = []
         for module_key in module_keys:
             if re.fullmatch(config.target_modules, module_key):  # noqa
                 module = model.get_submodule(module_key)
@@ -109,16 +118,26 @@ def _forward(self, *args, **kwargs):
                     else:
                         kwargs[config.attention_mask_pos] = attention_mask
 
-                    return self.forward_origin(*args, **kwargs)
+                    forward_output = self.forward_origin(*args, **kwargs)
+                    if config.extract_embedding:
+                        forward_output = getattr(
+                            self, 'prompt').extract(forward_output)
+
+                    return forward_output
 
                 module.forward_origin = module.forward
                 module.forward = types.MethodType(_forward, module)
-                prompt_module = PromptModule(config.dim,
+                if isinstance(config.dim, list):
+                    input_dim = config.dim[len(match_module_keys)]
+                else:
+                    input_dim = config.dim
+                prompt_module = PromptModule(input_dim,
                                              int(module_key.rsplit('.')[-1]),
                                              config.prompt_length,
                                              config.attention_mask_value,
                                              config.attach_front)
                 setattr(module, 'prompt', prompt_module)
+                match_module_keys.append(module_key)
 
         def state_dict_callback(state_dict):
             return {
@@ -184,3 +203,9 @@ def patch_attention_mask(self, m):
         prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length),
                                            self.mask_values).to(m.device)
         return torch.cat((prefix_attention_mask, m), dim=-1)
+
+    def extract(self, x):
+        if self.attach_front:
+            return x[:, self.prompt_length:, :]
+        else:
+            return x[:, :-self.prompt_length, :]
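A note on the list-valued `dim` handling introduced above: when `config.dim` is a list, the i-th module matched by `target_modules` receives `dim[i]`, with `match_module_keys` acting as a running counter. A minimal sketch of that selection logic, using made-up module names and dimensions rather than anything from this PR:

```python
# Hedged sketch of the per-module dim selection mirrored from the diff above.
# Module names and dimensions are illustrative assumptions, not SWIFT defaults.
config_dim = [768, 768, 1024, 1024]        # one hidden dim per matched module
matched = ['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3']

match_module_keys = []
for module_key in matched:
    if isinstance(config_dim, list):
        input_dim = config_dim[len(match_module_keys)]
    else:
        input_dim = config_dim
    print(module_key, '->', input_dim)     # e.g. blocks.2 -> 1024
    match_module_keys.append(module_key)
```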
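And a rough illustration of what the new `extract` method does when `extract_embedding=True`: after the wrapped forward runs, the prompt tokens are sliced back off the output so its sequence length matches the original input. The shapes below are illustrative assumptions:

```python
# Hedged sketch, not the library's test code: trim prompt tokens from a
# hidden-state tensor of shape (batch, prompt_length + seq_len, dim).
import torch

prompt_length, seq_len, dim = 5, 16, 768
hidden = torch.randn(2, prompt_length + seq_len, dim)

# attach_front=True: the prompt was prepended, so drop the leading tokens.
front_trimmed = hidden[:, prompt_length:, :]
assert front_trimmed.shape == (2, seq_len, dim)

# attach_front=False: the prompt was appended, so drop the trailing tokens.
back_trimmed = hidden[:, :-prompt_length, :]
assert back_trimmed.shape == (2, seq_len, dim)
```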