Commit aad7a74

update prompt tuner (#38)
1 parent b5b897a commit aad7a74

1 file changed: +27 −2 lines changed


swift/tuners/prompt.py

Lines changed: 27 additions & 2 deletions
@@ -31,6 +31,7 @@ class PromptConfig(SwiftConfig):
         attention_mask_value: The value to pad to the attention mask
         prompt_length: The length of the prompt tokens
         attach_front: When set to True, prompt is attached in front of the embedding
+        extract_embedding: Whether the embedding is extracted at final stage to keep the same dims with inputs
     """

     dim: int = field(
@@ -60,6 +61,13 @@ class PromptConfig(SwiftConfig):
             'help':
             'When set to True, prompt is attached in front of the embedding'
         })
+
+    extract_embedding: bool = field(
+        default=False,
+        metadata={
+            'help':
+            'Whether the embedding is extracted at final stage to keep the same dims with inputs'
+        })

     def __post_init__(self):
         from .mapping import SwiftTuners
@@ -71,6 +79,7 @@ class Prompt:
     @staticmethod
     def prepare_model(model: nn.Module, config: PromptConfig):
         module_keys = [key for key, _ in model.named_modules()]
+        match_module_keys = []
         for module_key in module_keys:
             if re.fullmatch(config.target_modules, module_key):  # noqa
                 module = model.get_submodule(module_key)
@@ -109,16 +118,26 @@ def _forward(self, *args, **kwargs):
                     else:
                         kwargs[config.attention_mask_pos] = attention_mask

-                    return self.forward_origin(*args, **kwargs)
+                    forward_output = self.forward_origin(*args, **kwargs)
+                    if config.extract_embedding:
+                        forward_output = getattr(
+                            self, 'prompt').extract(forward_output)
+
+                    return forward_output

                 module.forward_origin = module.forward
                 module.forward = types.MethodType(_forward, module)
-                prompt_module = PromptModule(config.dim,
+                if isinstance(config.dim, list):
+                    input_dim = config.dim[len(match_module_keys)]
+                else:
+                    input_dim = config.dim
+                prompt_module = PromptModule(input_dim,
                                              int(module_key.rsplit('.')[-1]),
                                              config.prompt_length,
                                              config.attention_mask_value,
                                              config.attach_front)
                 setattr(module, 'prompt', prompt_module)
+                match_module_keys.append(module_key)

         def state_dict_callback(state_dict):
             return {
@@ -184,3 +203,9 @@ def patch_attention_mask(self, m):
         prefix_attention_mask = torch.full((*m.shape[:-1], self.prompt_length),
                                            self.mask_values).to(m.device)
         return torch.cat((prefix_attention_mask, m), dim=-1)
+
+    def extract(self, x):
+        if self.attach_front:
+            return x[:, self.prompt_length:, :]
+        else:
+            return x[:, :-self.prompt_length, :]

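For orientation only, a rough sketch of how the new options might be set. Everything below is an assumption rather than part of the commit: the import path simply mirrors the file changed here, the regex and dim values are invented, `model` stands for any existing torch.nn.Module whose matched submodule names end in a layer index, and PromptConfig fields not passed are assumed to keep their defaults.

    from swift.tuners.prompt import Prompt, PromptConfig  # assumed import path

    config = PromptConfig(
        # After this patch, dim may be a list: one entry per matched module,
        # consumed in the order the modules are matched.
        dim=[768, 768, 1024],
        # Illustrative regex; matched keys must end in a layer index, since the
        # tuner builds PromptModule with int(module_key.rsplit('.')[-1]).
        target_modules=r'model\.layers\.\d+',
        prompt_length=16,
        attach_front=True,
        # New flag added by this commit: slice the prompt positions back off the
        # wrapped module's output so it keeps the same dims as the input.
        extract_embedding=True,
    )

    # prepare_model patches the matched modules in place (their forward is wrapped
    # and a PromptModule is attached); its return value is not used in this sketch.
    Prompt.prepare_model(model, config)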