Commit 1463f74

support ModuleToSave original module offloading (#282)
1 parent b441875 commit 1463f74
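
In short: this change teaches `ModulesToSaveWrapper` to offload whichever copy of a wrapped module is currently unused. The original module is moved to CPU as soon as the wrapper is constructed, kept offloaded while an adapter copy is active, restored when the last adapter is deactivated, and offloaded to disk-backed `meta` storage when adapters are disabled. Supporting changes: `Swift.from_pretrained` accepts a `Dict[str, str]` form of `adapter_name` to load a saved adapter under a new name, SCETuning saves the original `forward` per adapter, and disk reload gains a fallback for `torch < 2.1.0`, which lacks `load_state_dict(assign=True)`.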

File tree: 7 files changed (+114 lines, -30 lines)

swift/tuners/base.py

Lines changed: 40 additions & 17 deletions
@@ -49,13 +49,21 @@ def __init__(self,
             model = model.base_model

         if isinstance(config, SwiftConfig):
-            self.adapters[DEFAULT_ADAPTER] = self._prepare_model(
-                model, config, DEFAULT_ADAPTER)
+            if DEFAULT_ADAPTER not in self.adapters:
+                self.adapters[DEFAULT_ADAPTER] = self._prepare_model(
+                    model, config, DEFAULT_ADAPTER)
+            else:
+                logger.warn(
+                    f'Adapter {DEFAULT_ADAPTER} has been patched, skip.')
         elif isinstance(config, dict):
             assert (all(isinstance(c, SwiftConfig) for c in config.values()))
             for adapter_name, _config in config.items():
-                self.adapters[adapter_name] = self._prepare_model(
-                    model, _config, adapter_name)
+                if adapter_name not in self.adapters:
+                    self.adapters[adapter_name] = self._prepare_model(
+                        model, _config, adapter_name)
+                else:
+                    logger.warn(
+                        f'Adapter {adapter_name} has been patched, skip.')
         self.model = model

         self.extra_state_keys = extra_state_keys or []
@@ -195,7 +203,8 @@ def load_state_file(path):
     def from_pretrained(cls,
                         model: Union[nn.Module, 'SwiftModel'],
                         model_id: str = None,
-                        adapter_name: Union[str, List[str]] = None,
+                        adapter_name: Union[str, List[str], Dict[str,
+                                                                 str]] = None,
                         inference_mode: bool = False,
                         revision: str = None,
                         **kwargs):
@@ -205,7 +214,7 @@ def from_pretrained(cls,
             model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned,
                 if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped.
             model_id (`str`): The model_id or a local model dir of tuners to use to tune the model.
-            adapter_name (`Union[str, List[str]]`): The adapter_names saved in the model repo to load.
+            adapter_name (`Union[str, List[str], Dict[str, str]]`): The adapter_names saved in the model repo to load.
                 Default `None`, means load all tuners saved in the model_id
             inference_mode (`bool`): Use in the inference mode or not.
             revision (`str`): The model revision to use.
@@ -236,7 +245,8 @@
                 os.path.isfile(os.path.join(model_dir, sub_dir, CONFIG_NAME))
             ]
         for _name in adapter_name if isinstance(adapter_name,
-                                                list) else [adapter_name]:
+                                                list) else [adapter_name] \
+                if isinstance(adapter_name, str) else adapter_name.keys():
             sub_folder = os.path.join(model_dir, _name)
             config_file = os.path.join(sub_folder, CONFIG_NAME)

@@ -250,26 +260,31 @@ def from_pretrained(cls,
                 if SWIFT_TYPE_KEY not in json_object:
                     raise ValueError('Mixed using with peft is not allowed now.')
                 else:
-                    adapters[_name] = SwiftConfig.from_pretrained(sub_folder)
+                    key = _name if not isinstance(adapter_name,
+                                                  dict) else adapter_name[_name]
+                    adapters[key] = SwiftConfig.from_pretrained(sub_folder)

         self = SwiftModel(model, adapters, extra_state_keys, inference_mode,
                           **kwargs)
         for _name in adapter_name if isinstance(adapter_name,
-                                                list) else [adapter_name]:
+                                                list) else [adapter_name] \
+                if isinstance(adapter_name, str) else adapter_name.keys():
             sub_folder = os.path.join(model_dir, _name)
             state_dict = cls.load_state_file(sub_folder)
+            _adapter = _name if not isinstance(adapter_name,
+                                               dict) else adapter_name[_name]
             if state_dict is not None:
                 model_is_qlora = len([
                     k for k in self.state_dict().keys()
-                    if k.endswith('.lora_A.default.weight')
-                    or k.endswith('.lora_B.default.weight')
+                    if k.endswith(f'.lora_A.{_adapter}.weight')
+                    or k.endswith(f'.lora_B.{_adapter}.weight')
                 ])
                 if not model_is_qlora:
                     # model is lora, state_dict: qlora->lora
                     state_dict = {
-                        k[:-len('.default.weight') if k.
-                          endswith('.lora_A.default.weight') or k.
-                          endswith('.lora_B.default.weight') else None]: v
+                        k[:-len(f'.{_name}.weight') if k.
+                          endswith(f'.lora_A.{_name}.weight') or k.
+                          endswith(f'.lora_B.{_name}.weight') else None]: v
                         for k, v in state_dict.items()
                     }
                 if any(['loramodule' in key for key in state_dict]):
@@ -288,7 +303,13 @@ def from_pretrained(cls,
                             f'lora_B.{_name}.weight'): value
                         for key, value in state_dict.items()
                     }
-                self.load_state_dict(state_dict, adapter_name=_name)
+                if isinstance(adapter_name, dict):
+                    # TODO this logic is fragile! replacing `_name` may also replace other parts of the key
+                    state_dict = {
+                        key.replace(_name, adapter_name[_name]): value
+                        for key, value in state_dict.items()
+                    }
+                self.load_state_dict(state_dict, adapter_name=_adapter)
         state_dict = cls.load_state_file(model_dir)
         if state_dict is not None:
             self.load_state_dict(state_dict)
@@ -569,7 +590,8 @@ def unmerge(model: Union[PeftModel, SwiftModel], **kwargs):
     @staticmethod
     def from_pretrained(model: Union[nn.Module, SwiftModel],
                         model_id: str = None,
-                        adapter_name: Union[str, List[str]] = None,
+                        adapter_name: Union[str, List[str], Dict[str,
+                                                                 str]] = None,
                         revision: str = None,
                         **kwargs):
         """Prepare a model by a model_id in the ModelScope hub or a local dir.
@@ -593,7 +615,8 @@ def from_pretrained(model: Union[nn.Module, SwiftModel],
         is_peft_model = SWIFT_TYPE_KEY not in _json

         _name = adapter_name if isinstance(
-            adapter_name, str) or adapter_name is None else adapter_name[0]
+            adapter_name, str) or adapter_name is None else adapter_name[0] \
+            if isinstance(adapter_name, list) else list(adapter_name.keys())[0]
         _name = _name or ''
         if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)):
             with open(os.path.join(model_id, _name, CONFIG_NAME), 'r') as f:
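
The new `Dict[str, str]` form of `adapter_name` maps the directory name an adapter was saved under to the name it should be registered as after loading. A minimal sketch of the intended usage, mirroring the updated test at the bottom of this diff (the checkpoint path is illustrative):

    # load the tuner saved as 'default' but register it under 'my_adapter'
    model = Swift.from_pretrained(
        model, 'path/to/checkpoint', adapter_name={'default': 'my_adapter'})
    assert 'my_adapter' in model.adapters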

swift/tuners/lora.py

Lines changed: 2 additions & 1 deletion
@@ -78,7 +78,8 @@ def activate_adapter(module: torch.nn.Module,
         for sub_module in module.modules():
             if isinstance(sub_module, (LoraLayer, LoRALayer)):
                 sub_module.set_activation(adapter_name, activate)
-                sub_module.save_memory(adapter_name, activate, offload)
+                if hasattr(sub_module, 'save_memory'):
+                    sub_module.save_memory(adapter_name, activate, offload)

     @staticmethod
     def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
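
A likely rationale for the new `hasattr` guard (an inference, not stated in the diff): `module.modules()` matches both swift's own `LoRALayer` variants and peft's `LoraLayer`, and only the former implement `save_memory`, so the duck-typed check avoids an `AttributeError` on plain peft layers.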

swift/tuners/lora_layers.py

Lines changed: 2 additions & 1 deletion
@@ -272,7 +272,7 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):

                 if not isinstance(target, ModulesToSaveWrapper):
                     new_module = ModulesToSaveWrapper(
-                        target, adapter_name, module_key=key)
+                        target, adapter_name=adapter_name, module_key=key)
                     setattr(parent, target_name, new_module)
                 else:
                     target.update(adapter_name)
@@ -489,6 +489,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         elif lora_config.use_merged_linear:
             new_module = MergedLinear(
                 adapter_name,
+                current_key,
                 target,
                 bias=bias,
                 enable_lora=lora_config.enable_lora,
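
Note the new `current_key` positional argument: `MergedLinear` now receives the key of the module it wraps, presumably bringing it in line with the `module_key` bookkeeping that the offloading helpers in this commit rely on.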

swift/tuners/neftune.py

Lines changed: 4 additions & 2 deletions
@@ -64,8 +64,10 @@ def mark_trainable_callback(model):
             mark_trainable_callback)

     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
         for sub_module in module.modules():
             if isinstance(sub_module, torch.nn.Embedding):
                 sub_module.nef_activated = activate
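
The added `offload` parameter is accepted but unused here, presumably to keep `activate_adapter` signatures uniform across tuners (the LoRA version above accepts the same argument); NEFTune itself never offloads anything.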

swift/tuners/scetuning/scetuning.py

Lines changed: 7 additions & 3 deletions
@@ -154,7 +154,8 @@ def _get_module(module):

         # refactor forward function
         def _forward_encoder_mode(self, *args, **kwargs):
-            args = self.forward_origin(*args, **kwargs)
+            args = getattr(self, f'forward_origin_{adapter_name}')(*args,
+                                                                   **kwargs)
             args_type = type(args)
             if args_type is tuple:
                 args = args[0]
@@ -185,12 +186,15 @@ def _forward_decoder_mode(self, *args, **kwargs):
             if args_type is tuple:
                 args_main = (args_sub_tuner_new, *args_sub_extra)

-            args_main = self.forward_origin(*args_main, **kwargs)
+            args_main = getattr(self,
+                                f'forward_origin_{adapter_name}')(*args_main,
+                                                                  **kwargs)
             return args_main

         # 3. inject the tuners
         for tuner_id, t_module in enumerate(target_module_ins_list):
-            t_module.forward_origin = getattr(t_module, 'forward')
+            setattr(t_module, f'forward_origin_{adapter_name}',
+                    getattr(t_module, 'forward'))
             if config.tuner_mode in ('encoder', 'identity'):
                 _forward = _forward_encoder_mode
             elif config.tuner_mode == 'decoder':
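
Storing the saved forward as `forward_origin_{adapter_name}` instead of a single `forward_origin` attribute presumably keeps multiple SCETuning adapters patched onto the same module from clobbering each other's saved `forward`, since each patched forward now looks up its own original.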

swift/tuners/utils.py

Lines changed: 55 additions & 5 deletions
@@ -12,6 +12,7 @@
 import json
 import numpy as np
 import torch
+from packaging import version
 from peft.utils import CONFIG_NAME
 from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper
 from peft.utils import _get_submodules
@@ -252,7 +253,27 @@ def load_disk(module: torch.nn.Module, adapter_name, module_key):
             file = os.path.join(sub_folder, f'{key}.dat')
             state_dict[key] = OffloadHelper.load_offloaded_weight(
                 file, OffloadHelper.index[md5][key])
-        module.load_state_dict(state_dict, assign=True)
+        if version.parse(torch.__version__) >= version.parse('2.1.0'):
+            module.load_state_dict(state_dict, assign=True)
+        else:
+            for name, _module in module.named_modules():
+                if len(list(_module.modules())) > 1:
+                    continue
+
+                buffers = {}
+                prefix = name if not name else name + '.'
+                for sub_name, buffer in _module.named_buffers():
+                    buffer_cls = type(buffer)
+                    buffers[sub_name] = buffer_cls(state_dict[prefix
+                                                              + sub_name])
+                _module._buffers.update(buffers)
+                params = {}
+                for sub_name, param in _module.named_parameters():
+                    param_cls = type(param)
+                    params[sub_name] = param_cls(
+                        state_dict[prefix + sub_name],
+                        requires_grad=param.requires_grad)
+                _module._parameters.update(params)
         shutil.rmtree(sub_folder, ignore_errors=True)
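
For context, `load_state_dict(assign=True)` (added in torch 2.1) swaps the loaded tensors in as the module's parameters instead of copying values into existing storage, which is what allows restoring a module whose weights are placeholders on the `meta` device; the pre-2.1 branch above rebuilds each leaf module's `_parameters` and `_buffers` by hand to the same effect. A minimal standalone sketch of the idea (assumes torch >= 2.0 for the `meta`-device context manager; the Linear module is illustrative):

    from packaging import version
    import torch
    import torch.nn as nn

    # a module whose weights are shape-only placeholders on the meta device
    with torch.device('meta'):
        lin = nn.Linear(4, 4)

    state_dict = {'weight': torch.randn(4, 4), 'bias': torch.randn(4)}

    if version.parse(torch.__version__) >= version.parse('2.1.0'):
        # assign=True replaces the meta parameters with the loaded tensors
        lin.load_state_dict(state_dict, assign=True)
    else:
        # fallback: rebuild the parameters directly, as load_disk does above
        for name, param in list(lin.named_parameters()):
            lin._parameters[name] = nn.Parameter(
                state_dict[name], requires_grad=param.requires_grad)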

@@ -295,7 +316,7 @@ def offload(module: torch.nn.Module, adapter_name, module_key,
         if offload == 'cpu':
             if str(device) != 'cpu':
                 module.to('cpu')
-        if offload == 'meta':
+        elif offload == 'meta':
             if str(device) != 'meta':
                 OffloadHelper.offload_disk(
                     module, adapter_name=adapter_name, module_key=module_key)
@@ -331,6 +352,12 @@ class ModulesToSaveWrapper(ActivationMixin, _ModulesToSaveWrapper):
     def __init__(self, *args, module_key, **kwargs):
         super(ModulesToSaveWrapper, self).__init__(module_key)
         super(ActivationMixin, self).__init__(*args, **kwargs)
+        SwiftAdapter.save_memory(
+            self.original_module,
+            'original_module',
+            self.module_key,
+            False,
+            offload='cpu')

     @property
     def active_adapter(self):
@@ -343,7 +370,7 @@ def active_adapter(self):
             )
         return active_adapters[0]

-    def set_adapter(self, adapter_name: str, offload: str):
+    def set_adapter(self, adapter_name: str, offload: str = None):
         if adapter_name not in self.modules_to_save:
             raise ValueError(
                 f'Adapter {adapter_name} not found in {self.modules_to_save.keys()}'
@@ -352,8 +379,14 @@ def set_adapter(self, adapter_name: str, offload: str):
         self.set_activation(adapter_name, True)
         SwiftAdapter.save_memory(self.modules_to_save[adapter_name],
                                  adapter_name, self.module_key, True)
+        SwiftAdapter.save_memory(
+            self.original_module,
+            'original_module',
+            self.module_key,
+            False,
+            offload=offload)

-    def deactivate_adapter(self, adapter_name: str, offload: str):
+    def deactivate_adapter(self, adapter_name: str, offload: str = None):
         if adapter_name in self.modules_to_save and self.unique_thread:
             self.modules_to_save[adapter_name].requires_grad_(False)
         self.set_activation(adapter_name, False)
@@ -363,6 +396,22 @@ def deactivate_adapter(self, adapter_name: str, offload: str):
                 self.module_key,
                 False,
                 offload=offload)
+        if not self.get_activated_adapters():
+            SwiftAdapter.save_memory(self.original_module, 'original_module',
+                                     self.module_key, True)
+
+    def enable_adapters(self, enabled: bool):
+        super().enable_adapters(enabled)
+        if not enabled:
+            SwiftAdapter.save_memory(
+                self.original_module,
+                'original_module',
+                self.module_key,
+                False,
+                offload='meta')
+        else:
+            SwiftAdapter.save_memory(self.original_module, 'original_module',
+                                     self.module_key, True)


 def set_adapter(model, adapter_name, activate, offload):
@@ -385,6 +434,7 @@ def set_trainable(model, adapter_name):
             target.update(adapter_name)
             target.set_adapter(target.active_adapter)
         else:
-            new_module = ModulesToSaveWrapper(target, adapter_name)
+            new_module = ModulesToSaveWrapper(
+                target, module_key=key, adapter_name=adapter_name)
             new_module.set_adapter(adapter_name)
             setattr(parent, target_name, new_module)
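
Worth noting: `set_trainable` previously passed `adapter_name` positionally to a constructor whose signature is `__init__(self, *args, module_key, **kwargs)`, so the keyword-only `module_key` was never supplied on this code path; passing both `module_key=key` and `adapter_name=adapter_name` (mirroring the `inject_adapter` fix above) makes the wrapper construction consistent.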

tests/tuners/test_swift_base.py

Lines changed: 4 additions & 1 deletion
@@ -202,9 +202,12 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights):
             os.path.exists(
                 os.path.join(self.tmp_dir, 'default', WEIGHTS_NAME)))

-        model2 = Swift.from_pretrained(model2, self.tmp_dir)
+        model2 = Swift.from_pretrained(
+            model2, self.tmp_dir, adapter_name={'default': 'test'})
+        self.assertTrue('test' in model2.adapters)
         output2 = model2(**input)
         self.assertTrue(torch.allclose(output1.logits, output2.logits))
+        model2 = Swift.from_pretrained(model2, self.tmp_dir)
         state_dict = model.state_dict()
         state_dict2 = model2.state_dict()
         for key in state_dict:
