Skip to content

Commit 490349a

Browse files
support part tuner replace_key False (#2438)
1 parent d13c431 commit 490349a

File tree

11 files changed

+22
-16
lines changed

11 files changed

+22
-16
lines changed

swift/tuners/adapter.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def _feed_forward_chunk(self, attention_output):
                 setattr(module, f'adapter_{adapter_name}', adapter_module)
                 logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}')
 
-        def state_dict_callback(state_dict, adapter_name: str):
+        def state_dict_callback(state_dict, adapter_name: str, **kwargs):
             return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key}
 
         def mark_trainable_callback(model):

swift/tuners/llamapro.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -
         model.config.num_hidden_layers = len(new_module_list)
         LLaMAPro._set_module_list(config, model, new_module_list)
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
             new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
             return {

swift/tuners/longlora/longlora.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def prepare_model(model: nn.Module, config: LongLoRAConfig, adapter_name: str):
         """Prepare a model with `LongLoRAConfig`"""
         LoraModel(model, config, adapter_name)
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             _state_dict = lora_state_dict(state_dict, adapter_name, config.bias)
             for name, value in state_dict.items():
                 if isinstance(config.embedder_and_normalizer, str):

swift/tuners/lora.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
             config.group_size = getattr(auto_gptq_config, 'group_size', None)
         LoraModel(model, config, adapter_name)
 
-        def state_dict_callback(state_dict, adapter_name, cfg=None):
+        def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs):
             return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias)
 
         def mark_trainable_callback(model, cfg=None):

swift/tuners/neftune.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def neftune_hook(module, args, output):
             sub_module.register_forward_hook(neftune_hook)
             sub_module.nef_activated = True
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return state_dict
 
         def mark_trainable_callback(model):

swift/tuners/part.py

Lines changed: 12 additions & 6 deletions
@@ -70,11 +70,14 @@ def _forward(self, *args, **kwargs):
             setattr(module, f'_part_{adapter_name}', new_module)
             new_module.requires_grad_(True)
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             new_state_dict = {}
             for key, value in state_dict.items():
                 if f'_part_{adapter_name}.' in key:
-                    new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
+                    if kwargs.get('replace_key', True):
+                        new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
+                    else:
+                        new_key = key
                     new_state_dict[new_key] = value
 
             return new_state_dict
@@ -90,11 +93,14 @@ def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Di
                 for param_name in state_dict:
                     if param_name.startswith(name):
                         end = param_name[len(name):]
-                        if hasattr(module, 'base_layer'):
-                            new_state_dict[name + f'.base_layer._part_{adapter_name}'
-                                           + end] = state_dict[param_name]
+                        if '_part_' not in param_name:
+                            if hasattr(module, 'base_layer'):
+                                new_state_dict[name + f'.base_layer._part_{adapter_name}'
+                                               + end] = state_dict[param_name]
+                            else:
+                                new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
                         else:
-                            new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
+                            new_state_dict[param_name] = state_dict[param_name]
             return new_state_dict
 
         return SwiftOutput(

swift/tuners/prompt.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def _forward(self, *args, **kwargs):
                 logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}')
                 match_module_keys.append(module_key)
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key}
 
         def mark_trainable_callback(model):

swift/tuners/restuning.py

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ def _forward_restuning(self, origin_arg):
         if target_module_ins is None:
             raise Exception('Cannot match target modules')
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key}
 
         def mark_trainable_callback(model):

swift/tuners/rome/rome.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def prepare_model(model: nn.Module, config: RomeConfig, adapter_name: str):
         hparams = ROMEHyperParams.from_name(config.model_type)
         modified_keys = apply_rome_to_model(model, config.tokenizer, config.knowledge, hparams, config.batch_first)
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if key in modified_keys}
 
         def mark_trainable_callback(model):

swift/tuners/scetuning/scetuning.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def _forward_decoder_mode(self, *args, **kwargs):
             if len(hint_module_ins_list) > 0:
                 setattr(t_module, 'hint', hint_module_ins_list[tuner_id])
 
-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key}
             return state_dict_new

0 commit comments

Comments
 (0)