Commit cf30fa8: Support offload (#281)
1 parent 2c96e8e commit cf30fa8

File tree: 11 files changed, +306 -83 lines

swift/tuners/adapter.py

Lines changed: 11 additions & 5 deletions

@@ -131,6 +131,7 @@ def _feed_forward_chunk(self, attention_output):
                 setattr(module, config.method_name,
                         types.MethodType(_forward, module))
                 adapter_module = AdapterModule(config.dim, adapter_name,
+                                               module_key,
                                                config.adapter_length,
                                                ACT2CLS[config.act_layer])
                 setattr(module, f'adapter_{adapter_name}', adapter_module)
@@ -152,13 +153,17 @@ def mark_trainable_callback(model):
                           mark_trainable_callback)

     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        modules: List[torch.nn.Module] = find_sub_module(
-            module, f'adapter_{adapter_name}')
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        modules = find_sub_module(module, f'adapter_{adapter_name}')
         for _module in modules:
             _module: ActivationMixin
+            _module: nn.Module
             _module.set_activation(adapter_name, activate)
+            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
+                                     activate, offload)


 class AdapterModule(nn.Module, ActivationMixin):
@@ -177,11 +182,12 @@ def __init__(
         self,
         dim,
         adapter_name,
+        module_key,
         adapter_length=None,
         act_layer=nn.GELU,
     ):
         super(AdapterModule, self).__init__()
-        super(nn.Module, self).__init__()
+        super(nn.Module, self).__init__(module_key)
         self.dim = dim
         self.adapter_name = adapter_name
         self.adapter_length = adapter_length
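
The thread running through this commit: `activate_adapter` now takes an `offload` target and forwards it, together with each sub-module's `module_key`, to `SwiftAdapter.save_memory`. The sketch below only illustrates the general idea of such an offload helper; it is a minimal, assumption-laden stand-in (the `save_memory_sketch` name, the in-memory device cache, and the `'cpu'` target are illustrative, not the actual implementation in `swift/tuners/utils.py`):

import torch.nn as nn

# Illustrative bookkeeping: original device per (adapter, module_key) pair.
_DEVICE_CACHE = {}


def save_memory_sketch(module: nn.Module, adapter_name: str, module_key: str,
                       activate: bool, offload: str = None) -> None:
    """Park a deactivated adapter sub-module off the GPU and restore it later."""
    cache_key = f'{adapter_name}:{module_key}'
    params = list(module.parameters())
    if not params:
        return
    if activate:
        # Reactivation: move the weights back to where they came from.
        if cache_key in _DEVICE_CACHE:
            module.to(_DEVICE_CACHE.pop(cache_key))
    elif offload == 'cpu':
        # Deactivation with offload: remember the device, then free GPU memory.
        _DEVICE_CACHE[cache_key] = params[0].device
        module.to('cpu')
    # offload=None leaves the deactivated module where it is; a disk variant
    # would additionally serialize the state_dict before dropping it.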

swift/tuners/base.py

Lines changed: 6 additions & 4 deletions

@@ -432,7 +432,9 @@ def save_pretrained(self,
     def base_model(self):
         return self.model

-    def set_active_adapters(self, adapter_names: Union[List[str], str]):
+    def set_active_adapters(self,
+                            adapter_names: Union[List[str], str],
+                            offload=None):
         if not adapter_names:
             return

@@ -444,7 +446,7 @@ def set_active_adapters(self, adapter_names: Union[List[str], str]):
             self.activate_adapter(adapter_name)

         for adapter_name in (set(self.adapters.keys()) - adapter_names):
-            self.deactivate_adapter(adapter_name)
+            self.deactivate_adapter(adapter_name, offload)

     def activate_adapter(self, adapter_name):
         if adapter_name not in self.adapters:
@@ -456,15 +458,15 @@ def activate_adapter(self, adapter_name):
         SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
             .activate_adapter(self.base_model, adapter_name, True)

-    def deactivate_adapter(self, adapter_name):
+    def deactivate_adapter(self, adapter_name, offload=None):
         if adapter_name not in self.adapters:
             logger.warning(
                 f'{adapter_name} not in adapters: {self.adapters.keys()}')
             return

         from .mapping import SWIFT_MAPPING
         SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
-            .activate_adapter(self.base_model, adapter_name, False)
+            .activate_adapter(self.base_model, adapter_name, False, offload=offload)

     def get_trainable_parameters(self):
         """

swift/tuners/lora.py

Lines changed: 7 additions & 3 deletions

@@ -5,6 +5,7 @@

 import torch
 from packaging import version
+from peft.tuners.lora import LoraLayer

 from swift import LoraConfig
 from .lora_layers import *  # noqa
@@ -69,12 +70,15 @@ def mark_trainable_callback(model):
                           mark_trainable_callback)

     @staticmethod
-    def activate_adapter(module: torch.nn.Module, adapter_name: str,
-                         activate: bool):
-        set_adapter(module, adapter_name, activate)
+    def activate_adapter(module: torch.nn.Module,
+                         adapter_name: str,
+                         activate: bool,
+                         offload: str = None):
+        set_adapter(module, adapter_name, activate, offload)
         for sub_module in module.modules():
             if isinstance(sub_module, (LoraLayer, LoRALayer)):
                 sub_module.set_activation(adapter_name, activate)
+                sub_module.save_memory(adapter_name, activate, offload)

     @staticmethod
     def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
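
With `LoraLayer` now imported from peft, the activation loop covers both the peft-backed wrappers and swift's own `LoRALayer` variants, calling `set_activation` and the new `save_memory` on each. A quick, hypothetical way to confirm where the LoRA weights ended up after a switch with `offload='cpu'` (the helper below is not part of this commit):

import torch.nn as nn
from peft.tuners.lora import LoraLayer


def report_lora_devices(model: nn.Module) -> None:
    """Print the device type(s) holding each peft LoraLayer's parameters."""
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            devices = sorted({p.device.type for p in module.parameters()})
            print(f'{name}: {devices}')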

swift/tuners/lora_layers.py

Lines changed: 72 additions & 24 deletions

@@ -16,7 +16,6 @@
 from peft.tuners.lora import Conv2d as _Conv2d
 from peft.tuners.lora import Embedding as _Embedding
 from peft.tuners.lora import Linear as _Linear
-from peft.tuners.lora import LoraLayer
 from peft.tuners.lora import LoraModel as _LoraModel
 from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear
 from peft.tuners.tuners_utils import BaseTunerLayer
@@ -25,7 +24,7 @@
 from transformers import Conv1D

 from swift import get_logger
-from .utils import ActivationMixin, ModulesToSaveWrapper
+from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter

 logger = get_logger()

@@ -52,7 +51,7 @@ def active_adapters(self):
     def active_adapter(self) -> str:
         return self.get_activated_adapters()

-    def set_adapter(self, adapter_names):
+    def set_adapter(self, adapter_names, offload=None):
         if isinstance(adapter_names, str):
             adapter_names = [adapter_names]

@@ -63,9 +62,28 @@ def set_adapter(self, adapter_names):
                 if key in adapter_names:
                     self.set_activation(key, True)
                     layer.requires_grad_(True)
+                    SwiftAdapter.save_memory(layer, key, self.module_key, True)
                 else:
                     self.set_activation(key, False)
                     layer.requires_grad_(False)
+                    SwiftAdapter.save_memory(
+                        layer, key, self.module_key, False, offload=offload)
+
+    def save_memory(self, adapter_name, activate, offload=None):
+        for layer_name in self.adapter_layer_names:
+            module_dict = getattr(self, layer_name)
+            for key, layer in module_dict.items():
+                if key == adapter_name:
+                    if activate:
+                        SwiftAdapter.save_memory(layer, layer_name + '.' + key,
+                                                 self.module_key, True)
+                    else:
+                        SwiftAdapter.save_memory(
+                            layer,
+                            layer_name + '.' + key,
+                            self.module_key,
+                            False,
+                            offload=offload)

     def merge(self, *args, **kwargs):
         if not self.unique_thread:
@@ -85,9 +103,10 @@ class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt):
     def __init__(
         self,
         *args,
+        module_key: str,
         **kwargs,
     ):
-        super(Linear8bitLt, self).__init__()
+        super(Linear8bitLt, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)

@@ -100,9 +119,10 @@ class Linear4bit(LoRAActivationMixin, _Linear4bit):
     def __init__(
         self,
         *args,
+        module_key: str,
         **kwargs,
     ):
-        super(Linear4bit, self).__init__()
+        super(Linear4bit, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)

@@ -117,9 +137,10 @@ def __init__(
         *args,
         use_qa_lora=False,
         group_size=None,
+        module_key: str,
         **kwargs,
     ):
-        super(QuantLinear, self).__init__()
+        super(QuantLinear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
         self.group_size = group_size
@@ -166,33 +187,34 @@ class Embedding(LoRAActivationMixin, _Embedding):
     def __init__(
         self,
         *args,
+        module_key: str,
         **kwargs,
     ) -> None:
-        super(Embedding, self).__init__()
+        super(Embedding, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)


 class Linear(LoRAActivationMixin, _Linear):

-    def __init__(self, *args, **kwargs):
-        super(Linear, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(Linear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)


 class Conv2d(LoRAActivationMixin, _Conv2d):

-    def __init__(self, *args, **kwargs):
-        super(Conv2d, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(Conv2d, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)


 class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear):

-    def __init__(self, *args, **kwargs):
-        super(LoraParallelLinear, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(LoraParallelLinear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)

@@ -249,7 +271,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
                 parent, target, target_name = _get_submodules(model, key)

                 if not isinstance(target, ModulesToSaveWrapper):
-                    new_module = ModulesToSaveWrapper(target, adapter_name)
+                    new_module = ModulesToSaveWrapper(
+                        target, adapter_name, module_key=key)
                     setattr(parent, target_name, new_module)
                 else:
                     target.update(adapter_name)
@@ -384,8 +407,12 @@ def _create_and_replace(
             )
             self._convert_dtype(target, lora_config.lora_dtype)
         else:
-            new_module = self._create_new_module(lora_config, adapter_name,
-                                                 target, **kwargs)
+            new_module = self._create_new_module(
+                lora_config,
+                adapter_name,
+                target,
+                current_key=current_key,
+                **kwargs)
             if new_module is not None:
                 if adapter_name != self.active_adapter:
                     # adding an additional adapter: it is not automatically trainable
@@ -395,6 +422,7 @@ def _create_and_replace(

     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        current_key = kwargs.pop('current_key')
         gptq_quantization_config = kwargs.get('gptq_quantization_config', None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
             gptq_quantization_config)
@@ -422,7 +450,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'threshold': target.state.threshold,
                 'index': target.index,
             })
-            new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
+            new_module = Linear8bitLt(
+                target,
+                adapter_name,
+                module_key=current_key,
+                **eightbit_kwargs)
         elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(
                 target_base_layer, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
@@ -434,19 +466,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'quant_type':
                 target_base_layer.weight.quant_type,
             })
-            new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
+            new_module = Linear4bit(
+                target, adapter_name, module_key=current_key, **fourbit_kwargs)
         elif AutoGPTQQuantLinear is not None and isinstance(
                 target_base_layer, AutoGPTQQuantLinear):
-            new_module = QuantLinear(target, adapter_name, **kwargs)
+            new_module = QuantLinear(
+                target, adapter_name, module_key=current_key, **kwargs)
             target.qweight = target_base_layer.qweight
         elif isinstance(target_base_layer, torch.nn.Embedding):
             embedding_kwargs = kwargs.copy()
             embedding_kwargs.pop('fan_in_fan_out', None)
             embedding_kwargs.update(lora_config.loftq_config)
-            new_module = Embedding(target, adapter_name, **embedding_kwargs)
+            new_module = Embedding(
+                target,
+                adapter_name,
+                module_key=current_key,
+                **embedding_kwargs)
         elif isinstance(target_base_layer, torch.nn.Conv2d):
             kwargs.update(lora_config.loftq_config)
-            new_module = Conv2d(target, adapter_name, **kwargs)
+            new_module = Conv2d(
+                target, adapter_name, module_key=current_key, **kwargs)
         elif lora_config.use_merged_linear:
             new_module = MergedLinear(
                 adapter_name,
@@ -461,7 +500,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'Setting fan_in_fan_out to False.')
             kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False
             kwargs.update(lora_config.loftq_config)
-            new_module = Linear(target, adapter_name, **kwargs)
+            new_module = Linear(
+                target, adapter_name, module_key=current_key, **kwargs)
         elif megatron_core and isinstance(
                 target_base_layer,  # noqa
                 (  # noqa
@@ -486,6 +526,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
             new_module = LoraParallelLinear(
                 base_layer=target,
                 adapter_name=adapter_name,
+                module_key=current_key,
                 backend=megatron_core.tensor_parallel,
                 **megatron_kwargs)
         elif isinstance(target_base_layer, Conv1D):
@@ -496,7 +537,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True
             kwargs.update(lora_config.loftq_config)
             new_module = Linear(
-                target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
+                target,
+                adapter_name,
+                module_key=current_key,
+                is_target_conv_1d_layer=True,
+                **kwargs)
         else:
             logger.debug(
                 f'Target module {target} is not supported. Currently, only the following modules are supported: '
@@ -512,12 +557,13 @@ class LoRALayer(ActivationMixin):
     def __init__(
         self,
         adapter_name: str,
+        module_key: str,
         r: int,
         lora_alpha: int,
         lora_dropout: float,
         merge_weights: bool,
     ):
-        super().__init__()
+        super().__init__(module_key)
         self.adapter_name = adapter_name
         self.r = r
         self.lora_alpha = lora_alpha
@@ -537,6 +583,7 @@ class MergedLinear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(self,
                  adapter_name: str,
+                 module_key: str,
                  base_layer: nn.Linear,
                  r: int = 0,
                  lora_alpha: int = 1,
@@ -558,6 +605,7 @@ def __init__(self,
         LoRALayer.__init__(
             self,
             adapter_name,
+            module_key,
             r=r,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
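
The other constant in this diff is `module_key`: every wrapper (`Linear8bitLt`, `Embedding`, `MergedLinear`, and so on) now receives the dotted path of the layer it wraps and hands it to `ActivationMixin.__init__`, so offloaded tensors can be keyed per layer. A stripped-down sketch of what such a mixin might track, with names of my own choosing (the real `ActivationMixin` in `swift/tuners/utils.py` also handles thread-local activation state and is more involved):

class ActivationMixinSketch:
    """Illustrative stand-in for swift's ActivationMixin."""

    def __init__(self, module_key: str):
        # Dotted path of the wrapped layer inside the model, e.g.
        # 'transformer.h.0.attn.c_attn'; an offload helper can use it to key
        # the parked weights of exactly this layer.
        self.module_key = module_key
        self._activation_state = {}

    def set_activation(self, adapter_name: str, activate: bool) -> None:
        self._activation_state[adapter_name] = activate

    def is_activated(self, adapter_name: str) -> bool:
        return self._activation_state.get(adapter_name, False)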
