
Commit 87b6355

Fix model saving in new format (#198)

(cherry picked from commit c6bf657)
1 parent acf3665

File tree

5 files changed: +293 -13 lines changed

swift/tuners/base.py

Lines changed: 46 additions & 9 deletions
@@ -3,6 +3,7 @@
 import inspect
 import os
 import re
+from copy import copy
 from inspect import Parameter, Signature, signature
 from types import MethodType
 from typing import Dict, List, Union
@@ -118,9 +119,35 @@ def forward(self, *args, **kwargs):
         else:
             for output in self.adapters.values():
                 output.mark_trainable_callback(model)
-
-    def load_state_dict(self, state_dict, strict=True):
-        return self.model.load_state_dict(state_dict, False)
+        if self.extra_state_keys:
+            for n, p in model.named_parameters():
+                if any(
+                        re.fullmatch(extra_key, n)
+                        for extra_key in self.extra_state_keys):
+                    p.requires_grad = True
+
+    def load_state_dict(self,
+                        state_dict,
+                        strict=True,
+                        adapter_name: str = None):
+        if adapter_name is not None:
+            output = self.adapters[adapter_name]
+            if getattr(output.config, 'modules_to_save', None):
+                for key, value in copy(state_dict).items():
+                    for module_name in output.config.modules_to_save:
+                        if module_name in key:
+                            state_dict.pop(key)
+                            key = key.replace(
+                                module_name,
+                                f'{module_name}.modules_to_save.{adapter_name}'
+                            )
+                            break
+                    state_dict[key] = value
+        incompatible_keys = self.model.load_state_dict(state_dict, False)
+        if len(incompatible_keys[1]) > 0:
+            logger.error(
+                f'Load state dict with unexpected keys: {incompatible_keys[1]}'
+            )

     def state_dict(self,
                    *args,
@@ -149,18 +176,28 @@ def state_dict(self,
         Returns:
             The state dict to be saved.
         """
-        destination = self.model.state_dict(
+        state_dict = self.model.state_dict(
             destination=destination, prefix=prefix, keep_vars=keep_vars)
         state_dicts = {}
         if kwargs.get('save_adapter', True):
             for name, output in self.adapters.items():
                 if adapter_name == name or adapter_name is None:
                     state_dicts.update(
-                        output.state_dict_callback(destination, name))
+                        output.state_dict_callback(state_dict, name))
+                    modules_to_save_names = [
+                        sub_name
+                        for sub_name, _ in self.model.named_parameters()
+                        if 'modules_to_save' in sub_name
+                    ]
+                    for module_name in modules_to_save_names:
+                        if f'modules_to_save.{name}' in module_name:
+                            state_dicts[module_name.replace(
+                                f'modules_to_save.{name}.',
+                                '')] = state_dict[module_name]
         if kwargs.get('save_extra_states', True):
             state_dicts.update({
                 k: v
-                for k, v in destination.items() if any(
+                for k, v in state_dict.items() if any(
                     re.fullmatch(extra_key, k)
                     for extra_key in self.extra_state_keys)
             })
@@ -289,10 +326,10 @@ def from_pretrained(cls,
                        f'lora_B.{_name}.weight'): value
                    for key, value in state_dict.items()
                }
-                self.model.load_state_dict(state_dict, strict=False)
+                self.load_state_dict(state_dict, adapter_name=_name)
            state_dict = cls.load_state_file(model_dir)
            if state_dict is not None:
-                self.model.load_state_dict(state_dict, strict=False)
+                self.load_state_dict(state_dict)
        return self

    @classmethod
@@ -597,7 +634,7 @@ def from_pretrained(model: Union[nn.Module, SwiftModel],
        if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)):
            with open(os.path.join(model_id, _name, CONFIG_NAME), 'r') as f:
                _json = json.load(f)
-                is_peft_model = SWIFT_TYPE_KEY not in _json
+                is_peft_model = SWIFT_TYPE_KEY not in _json and 'extra_state_keys' not in _json
        if is_peft_model:
            return PeftModel.from_pretrained(
                model,
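A quick sketch of what the new key remapping in load_state_dict does: checkpoints written by the updated state_dict hold modules_to_save weights under their plain module names, and on load each such key is expanded back to the per-adapter path (<module>.modules_to_save.<adapter>) that the wrapped model expects. Below is a minimal standalone sketch of that remapping; the remap_modules_to_save helper and the sample keys are illustrative, not part of the commit.

# Illustrative sketch only: mirrors the key remapping done by
# SwiftModel.load_state_dict when adapter_name is given.
from copy import copy


def remap_modules_to_save(state_dict, modules_to_save, adapter_name):
    remapped = {}
    for key, value in copy(state_dict).items():
        for module_name in modules_to_save:
            if module_name in key:
                # e.g. 'lm_head.weight' -> 'lm_head.modules_to_save.default.weight'
                key = key.replace(
                    module_name,
                    f'{module_name}.modules_to_save.{adapter_name}')
                break
        remapped[key] = value
    return remapped


saved = {
    'lm_head.weight': 'tensor_a',                      # a modules_to_save weight
    'model.q_proj.lora_A.default.weight': 'tensor_b',  # an ordinary LoRA weight
}
print(remap_modules_to_save(saved, ['lm_head'], 'default'))
# {'lm_head.modules_to_save.default.weight': 'tensor_a',
#  'model.q_proj.lora_A.default.weight': 'tensor_b'}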

swift/tuners/lora.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@

 from swift import LoraConfig
 from .lora_layers import * # noqa
-from .utils import SwiftAdapter, SwiftConfig, SwiftOutput
+from .utils import SwiftAdapter, SwiftConfig, SwiftOutput, set_adapter

 logger = get_logger()

@@ -64,6 +64,7 @@ def mark_trainable_callback(model):
     @staticmethod
     def activate_adapter(module: torch.nn.Module, adapter_name: str,
                          activate: bool):
+        set_adapter(module, adapter_name, activate)
         for sub_module in module.modules():
             if isinstance(sub_module, (LoraLayer, LoRALayer)):
                 sub_module.set_activation(adapter_name, activate)
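The practical effect of the added call is that activating or deactivating a LoRA adapter now also toggles any ModulesToSaveWrapper copies before the LoraLayer activations are switched. A toy sketch of that dispatch pattern, assuming a made-up ToyAdapterLayer in place of the real LoraLayer/LoRALayer classes:

# Toy illustration of the activation dispatch performed by LoRA.activate_adapter.
# ToyAdapterLayer is a stand-in; the real first pass goes through set_adapter
# from swift/tuners/utils.py to reach ModulesToSaveWrapper instances.
import torch.nn as nn


class ToyAdapterLayer(nn.Module):

    def __init__(self):
        super().__init__()
        self.activated = {}

    def set_activation(self, adapter_name, activate):
        self.activated[adapter_name] = activate


def activate_adapter(module, adapter_name, activate):
    # Walk the module tree and flip activation on every adapter-aware layer.
    for sub_module in module.modules():
        if isinstance(sub_module, ToyAdapterLayer):
            sub_module.set_activation(adapter_name, activate)


model = nn.Sequential(nn.Linear(4, 4), ToyAdapterLayer())
activate_adapter(model, 'default', True)
print(model[1].activated)  # {'default': True}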

swift/tuners/lora_layers.py

Lines changed: 84 additions & 2 deletions
@@ -17,11 +17,12 @@
 from peft.tuners.lora import Linear as _Linear
 from peft.tuners.lora import LoraLayer
 from peft.tuners.lora import LoraModel as _LoraModel
-from peft.utils import get_auto_gptq_quant_linear, get_quantization_config
+from peft.utils import (_get_submodules, get_auto_gptq_quant_linear,
+                        get_quantization_config)
 from transformers import Conv1D

 from swift import get_logger
-from .utils import ActivationMixin
+from .utils import ActivationMixin, ModulesToSaveWrapper

 logger = get_logger()

@@ -202,6 +203,87 @@ def __init__(self, model, config, adapter_name):
         nn.Module.__init__(self)
         self.model = model

+    def inject_adapter(self, model: nn.Module, adapter_name: str):
+        r"""
+        Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
+        hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed.
+
+        The corresponding PEFT config is directly retrieved from the `peft_config` attribute of the BaseTuner class.
+
+        Args:
+            model (`nn.Module`):
+                The model to be tuned.
+            adapter_name (`str`):
+                The adapter name.
+        """
+        peft_config = self.peft_config[adapter_name]
+        # Note: If possible, all checks should be performed *at the start of this method*.
+        # This way, we can raise early if something goes wrong, without leaving the model
+        # in a bad (half-initialized) state.
+        self._check_new_adapter_config(peft_config)
+
+        is_target_modules_in_base_model = False
+        key_list = [key for key, _ in model.named_modules()]
+
+        _check_for_modules_to_save = getattr(peft_config, 'modules_to_save',
+                                             None) is not None
+        _has_modules_to_save = False
+
+        model_config = getattr(model, 'config', {'model_type': 'custom'})
+        if hasattr(model_config, 'to_dict'):
+            model_config = model_config.to_dict()
+
+        peft_config = self._prepare_adapter_config(peft_config, model_config)
+
+        for key in key_list:
+            # Check for modules_to_save in case
+            if _check_for_modules_to_save and any(
+                    key.endswith(f'{module_to_save}')
+                    for module_to_save in peft_config.modules_to_save):
+                # Optionally set the modules to save
+                parent, target, target_name = _get_submodules(model, key)
+
+                if not isinstance(target, ModulesToSaveWrapper):
+                    new_module = ModulesToSaveWrapper(target, adapter_name)
+                    setattr(parent, target_name, new_module)
+                else:
+                    target.update(adapter_name)
+
+                _has_modules_to_save = True
+                continue
+
+            if not self._check_target_module_exists(peft_config, key):
+                continue
+
+            is_target_modules_in_base_model = True
+            parent, target, target_name = _get_submodules(model, key)
+
+            optional_kwargs = {
+                'loaded_in_8bit': getattr(model, 'is_loaded_in_8bit', False),
+                'loaded_in_4bit': getattr(model, 'is_loaded_in_4bit', False),
+                'current_key': key,
+            }
+            self._create_and_replace(peft_config, adapter_name, target,
+                                     target_name, parent, **optional_kwargs)
+
+        if not is_target_modules_in_base_model:
+            raise ValueError(
+                f'Target modules {peft_config.target_modules} not found in the base model. '
+                f'Please check the target modules and try again.')
+
+        self._mark_only_adapters_as_trainable()
+
+        if self.peft_config[adapter_name].inference_mode:
+            for n, p in self.model.named_parameters():
+                if adapter_name in n:
+                    p.requires_grad = False
+
+        if _has_modules_to_save:
+            if not hasattr(model, 'modules_to_save'):
+                model.modules_to_save = set(peft_config.modules_to_save)
+            else:
+                model.modules_to_save.update(set(peft_config.modules_to_save))
+
     def _create_and_replace(
         self,
         lora_config,
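The override largely mirrors peft's LoraModel.inject_adapter, with one Swift-specific twist: targets listed in modules_to_save are wrapped in Swift's activation-aware ModulesToSaveWrapper instead of receiving LoRA layers. A self-contained toy sketch of that match-and-replace step follows; TinyWrapper, ToyModel and wrap_modules_to_save are illustrative stand-ins, not the commit's code.

# Toy illustration of the modules_to_save branch: any submodule whose qualified
# name ends with an entry in modules_to_save is replaced by a wrapper that keeps
# the original module plus a trainable per-adapter copy.
import copy

import torch.nn as nn


class TinyWrapper(nn.Module):

    def __init__(self, module, adapter_name):
        super().__init__()
        self.original_module = module
        # One trainable copy per adapter, addressed by adapter name.
        self.modules_to_save = nn.ModuleDict(
            {adapter_name: copy.deepcopy(module)})


def wrap_modules_to_save(model, modules_to_save, adapter_name):
    for key, _ in list(model.named_modules()):
        if key and any(key.endswith(m) for m in modules_to_save):
            parent_name, _, child_name = key.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name,
                    TinyWrapper(getattr(parent, child_name), adapter_name))
    return model


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 2)


model = wrap_modules_to_save(ToyModel(), ['lm_head'], 'default')
print(type(model.lm_head).__name__)  # TinyWrapper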

swift/tuners/utils.py

Lines changed: 67 additions & 1 deletion
@@ -5,14 +5,20 @@
 import threading
 from dataclasses import asdict, dataclass, field
 from types import FunctionType
-from typing import Dict
+from typing import Dict, List, Optional

 import json
+import peft.utils
 import torch
 from peft.utils import CONFIG_NAME
+from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper
+from peft.utils import _get_submodules

 from swift.hub.snapshot_download import snapshot_download
 from swift.utils.constants import BIN_EXTENSIONS
+from swift.utils.logger import get_logger
+
+logger = get_logger()


 @dataclass
@@ -138,6 +144,10 @@ def __init__(self):
         self._thread_inf: Dict[int, Dict[str, bool]] = {}
         self._unique_thread = bool(
             int(os.environ.get(ActivationMixin.USE_UNIQUE_THREAD, '1')))
+        if not self._unique_thread:
+            logger.info(
+                'Using multiple thread mode, gradient checkpointing is not supported.'
+            )

     @property
     def indent(self):
@@ -180,3 +190,59 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str,
     @staticmethod
     def freeze_model():
         return True
+
+
+class ModulesToSaveWrapper(ActivationMixin, _ModulesToSaveWrapper):
+
+    def __init__(self, *args, **kwargs):
+        super(ModulesToSaveWrapper, self).__init__()
+        super(ActivationMixin, self).__init__(*args, **kwargs)
+
+    @property
+    def active_adapter(self):
+        active_adapters = self.get_activated_adapters()
+        if not active_adapters:
+            return None
+        elif len(active_adapters) > 1:
+            raise ValueError(
+                'ModulesToSaveWrapper does not support multiple active adapters'
+            )
+        return active_adapters[0]
+
+    def set_adapter(self, adapter_name: str):
+        if adapter_name not in self.modules_to_save:
+            raise ValueError(
+                f'Adapter {adapter_name} not found in {self.modules_to_save.keys()}'
+            )
+        self.modules_to_save[adapter_name].requires_grad_(True)
+        self.set_activation(adapter_name, True)
+
+    def deactivate_adapter(self, adapter_name: str):
+        if adapter_name in self.modules_to_save and self.unique_thread:
+            self.modules_to_save[adapter_name].requires_grad_(False)
+        self.set_activation(adapter_name, False)
+
+
+def set_adapter(model, adapter_name, activate):
+    for module in model.modules():
+        if isinstance(module, ModulesToSaveWrapper):
+            if activate:
+                module.set_adapter(adapter_name)
+            else:
+                module.deactivate_adapter(adapter_name)
+
+
+def set_trainable(model, adapter_name):
+    key_list = [key for key, _ in model.named_modules()]
+    for key in key_list:
+        target_module_found = any(
+            key.endswith(target_key) for target_key in model.modules_to_save)
+        if target_module_found:
+            parent, target, target_name = _get_submodules(model, key)
+            if isinstance(target, ModulesToSaveWrapper):
+                target.update(adapter_name)
+                target.set_adapter(target.active_adapter)
+            else:
+                new_module = ModulesToSaveWrapper(target, adapter_name)
+                new_module.set_adapter(adapter_name)
+                setattr(parent, target_name, new_module)
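The wrapper builds on the per-thread activation bookkeeping that ActivationMixin already provides; the __init__ change above only adds the warning that multi-thread mode and gradient checkpointing do not mix, since checkpoint recomputation may run on a different thread and would see a different activation table. A simplified, hedged sketch of that bookkeeping with made-up names, not the real class:

# Simplified sketch of per-thread activation tracking in the spirit of
# ActivationMixin; TinyActivationMixin and its members are illustrative only.
import os
import threading
from typing import Dict


class TinyActivationMixin:
    USE_UNIQUE_THREAD = 'USE_UNIQUE_THREAD'

    def __init__(self):
        self._thread_inf: Dict[int, Dict[str, bool]] = {}
        self.unique_thread = bool(
            int(os.environ.get(TinyActivationMixin.USE_UNIQUE_THREAD, '1')))

    @property
    def _key(self):
        # Unique-thread mode shares one table; otherwise each thread keeps its own.
        return 0 if self.unique_thread else threading.get_ident()

    def set_activation(self, adapter_name: str, activate: bool):
        self._thread_inf.setdefault(self._key, {})[adapter_name] = activate

    def get_activated_adapters(self):
        return [
            name for name, active in self._thread_inf.get(self._key, {}).items()
            if active
        ]


mixin = TinyActivationMixin()
mixin.set_activation('default', True)
print(mixin.get_activated_adapters())  # ['default']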
