 from peft.tuners.lora import Conv2d as _Conv2d
 from peft.tuners.lora import Embedding as _Embedding
 from peft.tuners.lora import Linear as _Linear
-from peft.tuners.lora import LoraLayer
 from peft.tuners.lora import LoraModel as _LoraModel
 from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear
 from peft.tuners.tuners_utils import BaseTunerLayer
 from transformers import Conv1D
 
 from swift import get_logger
-from .utils import ActivationMixin, ModulesToSaveWrapper
+from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter
 
 logger = get_logger()
 
@@ -52,7 +51,7 @@ def active_adapters(self):
     def active_adapter(self) -> str:
         return self.get_activated_adapters()
 
-    def set_adapter(self, adapter_names):
+    def set_adapter(self, adapter_names, offload=None):
         if isinstance(adapter_names, str):
             adapter_names = [adapter_names]
 
@@ -63,9 +62,28 @@ def set_adapter(self, adapter_names):
                 if key in adapter_names:
                     self.set_activation(key, True)
                     layer.requires_grad_(True)
+                    SwiftAdapter.save_memory(layer, key, self.module_key, True)
                 else:
                     self.set_activation(key, False)
                     layer.requires_grad_(False)
+                    SwiftAdapter.save_memory(
+                        layer, key, self.module_key, False, offload=offload)
+
+    def save_memory(self, adapter_name, activate, offload=None):
+        for layer_name in self.adapter_layer_names:
+            module_dict = getattr(self, layer_name)
+            for key, layer in module_dict.items():
+                if key == adapter_name:
+                    if activate:
+                        SwiftAdapter.save_memory(layer, layer_name + '.' + key,
+                                                 self.module_key, True)
+                    else:
+                        SwiftAdapter.save_memory(
+                            layer,
+                            layer_name + '.' + key,
+                            self.module_key,
+                            False,
+                            offload=offload)
 
     def merge(self, *args, **kwargs):
         if not self.unique_thread:
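
The hunk above extends `set_adapter` with an `offload` argument: adapters that stay active are reported to `SwiftAdapter.save_memory` with `True` so they remain resident, while deactivated adapters are frozen and handed over with the offload hint. Below is a standalone Python sketch of that activate/offload pattern; `toggle_adapter` and its device handling are illustrative assumptions, not the actual `SwiftAdapter` API.

from typing import Optional

import torch.nn as nn


def toggle_adapter(layer: nn.Module, active: bool,
                   offload: Optional[str] = None,
                   compute_device: str = 'cpu') -> None:
    # Active adapters keep gradients and stay on the compute device;
    # inactive ones are frozen and, if requested, moved to the offload target.
    layer.requires_grad_(active)
    if active:
        layer.to(compute_device)
    elif offload is not None:
        layer.to(offload)  # e.g. offload='cpu' to free accelerator memory


adapter_a, adapter_b = nn.Linear(8, 8), nn.Linear(8, 8)
toggle_adapter(adapter_a, active=True)                  # keep A trainable and resident
toggle_adapter(adapter_b, active=False, offload='cpu')  # freeze B and park it on the CPU
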
@@ -85,9 +103,10 @@ class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt):
         def __init__(
             self,
             *args,
+            module_key: str,
             **kwargs,
         ):
-            super(Linear8bitLt, self).__init__()
+            super(Linear8bitLt, self).__init__(module_key)
             self.set_activation(args[1], True)
             super(ActivationMixin, self).__init__(*args, **kwargs)
 
@@ -100,9 +119,10 @@ class Linear4bit(LoRAActivationMixin, _Linear4bit):
         def __init__(
             self,
             *args,
+            module_key: str,
             **kwargs,
         ):
-            super(Linear4bit, self).__init__()
+            super(Linear4bit, self).__init__(module_key)
             self.set_activation(args[1], True)
             super(ActivationMixin, self).__init__(*args, **kwargs)
 
@@ -117,9 +137,10 @@ def __init__(
             *args,
             use_qa_lora=False,
             group_size=None,
+            module_key: str,
             **kwargs,
         ):
-            super(QuantLinear, self).__init__()
+            super(QuantLinear, self).__init__(module_key)
             self.set_activation(args[1], True)
             super(ActivationMixin, self).__init__(*args, **kwargs)
             self.group_size = group_size
@@ -166,33 +187,34 @@ class Embedding(LoRAActivationMixin, _Embedding):
     def __init__(
         self,
         *args,
+        module_key: str,
         **kwargs,
     ) -> None:
-        super(Embedding, self).__init__()
+        super(Embedding, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
 
 class Linear(LoRAActivationMixin, _Linear):
 
-    def __init__(self, *args, **kwargs):
-        super(Linear, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(Linear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
 
 class Conv2d(LoRAActivationMixin, _Conv2d):
 
-    def __init__(self, *args, **kwargs):
-        super(Conv2d, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(Conv2d, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
 
 class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear):
 
-    def __init__(self, *args, **kwargs):
-        super(LoraParallelLinear, self).__init__()
+    def __init__(self, *args, module_key: str, **kwargs):
+        super(LoraParallelLinear, self).__init__(module_key)
         self.set_activation(args[1], True)
         super(ActivationMixin, self).__init__(*args, **kwargs)
 
@@ -249,7 +271,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
                 parent, target, target_name = _get_submodules(model, key)
 
                 if not isinstance(target, ModulesToSaveWrapper):
-                    new_module = ModulesToSaveWrapper(target, adapter_name)
+                    new_module = ModulesToSaveWrapper(
+                        target, adapter_name, module_key=key)
                     setattr(parent, target_name, new_module)
                 else:
                     target.update(adapter_name)
@@ -384,8 +407,12 @@ def _create_and_replace(
             )
             self._convert_dtype(target, lora_config.lora_dtype)
         else:
-            new_module = self._create_new_module(lora_config, adapter_name,
-                                                 target, **kwargs)
+            new_module = self._create_new_module(
+                lora_config,
+                adapter_name,
+                target,
+                current_key=current_key,
+                **kwargs)
             if new_module is not None:
                 if adapter_name != self.active_adapter:
                     # adding an additional adapter: it is not automatically trainable
@@ -395,6 +422,7 @@ def _create_and_replace(
 
     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        current_key = kwargs.pop('current_key')
         gptq_quantization_config = kwargs.get('gptq_quantization_config', None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
             gptq_quantization_config)
@@ -422,7 +450,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'threshold': target.state.threshold,
                 'index': target.index,
             })
-            new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
+            new_module = Linear8bitLt(
+                target,
+                adapter_name,
+                module_key=current_key,
+                **eightbit_kwargs)
         elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(
                 target_base_layer, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
@@ -434,19 +466,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 'quant_type':
                 target_base_layer.weight.quant_type,
             })
-            new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
+            new_module = Linear4bit(
+                target, adapter_name, module_key=current_key, **fourbit_kwargs)
         elif AutoGPTQQuantLinear is not None and isinstance(
                 target_base_layer, AutoGPTQQuantLinear):
-            new_module = QuantLinear(target, adapter_name, **kwargs)
+            new_module = QuantLinear(
+                target, adapter_name, module_key=current_key, **kwargs)
             target.qweight = target_base_layer.qweight
         elif isinstance(target_base_layer, torch.nn.Embedding):
             embedding_kwargs = kwargs.copy()
             embedding_kwargs.pop('fan_in_fan_out', None)
             embedding_kwargs.update(lora_config.loftq_config)
-            new_module = Embedding(target, adapter_name, **embedding_kwargs)
+            new_module = Embedding(
+                target,
+                adapter_name,
+                module_key=current_key,
+                **embedding_kwargs)
         elif isinstance(target_base_layer, torch.nn.Conv2d):
             kwargs.update(lora_config.loftq_config)
-            new_module = Conv2d(target, adapter_name, **kwargs)
+            new_module = Conv2d(
+                target, adapter_name, module_key=current_key, **kwargs)
         elif lora_config.use_merged_linear:
             new_module = MergedLinear(
                 adapter_name,
@@ -461,7 +500,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                     'Setting fan_in_fan_out to False.')
                 kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False
             kwargs.update(lora_config.loftq_config)
-            new_module = Linear(target, adapter_name, **kwargs)
+            new_module = Linear(
+                target, adapter_name, module_key=current_key, **kwargs)
         elif megatron_core and isinstance(
                 target_base_layer,  # noqa
                 (  # noqa
@@ -486,6 +526,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
             new_module = LoraParallelLinear(
                 base_layer=target,
                 adapter_name=adapter_name,
+                module_key=current_key,
                 backend=megatron_core.tensor_parallel,
                 **megatron_kwargs)
         elif isinstance(target_base_layer, Conv1D):
@@ -496,7 +537,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True
             kwargs.update(lora_config.loftq_config)
             new_module = Linear(
-                target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
+                target,
+                adapter_name,
+                module_key=current_key,
+                is_target_conv_1d_layer=True,
+                **kwargs)
         else:
             logger.debug(
                 f'Target module {target} is not supported. Currently, only the following modules are supported: '
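
Across the dispatch branches above, the change is uniform: `_create_and_replace` forwards `current_key` (the fully qualified name of the target module), `_create_new_module` pops it out of `kwargs` so it never reaches the wrapped constructors through `**kwargs`, and each branch re-passes it explicitly as `module_key=current_key`. A small sketch of that pop-then-forward convention, using hypothetical names rather than the real signatures:

def make_wrapper(target, adapter_name, *, module_key, **layer_kwargs):
    # Stand-in for any of the wrapper constructors above.
    return {'target': target, 'adapter': adapter_name,
            'module_key': module_key, 'layer_kwargs': layer_kwargs}


def create_new_module(adapter_name, target, **kwargs):
    # Pop current_key first so it does not leak into **kwargs of constructors
    # that would reject it as an unexpected keyword argument.
    current_key = kwargs.pop('current_key')
    return make_wrapper(target, adapter_name, module_key=current_key, **kwargs)


module = create_new_module('default', 'dummy_linear',
                           current_key='model.decoder.layers.0.fc1', r=8)
assert module['module_key'] == 'model.decoder.layers.0.fc1'
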
@@ -512,12 +557,13 @@ class LoRALayer(ActivationMixin):
     def __init__(
         self,
         adapter_name: str,
+        module_key: str,
         r: int,
         lora_alpha: int,
         lora_dropout: float,
         merge_weights: bool,
     ):
-        super().__init__()
+        super().__init__(module_key)
         self.adapter_name = adapter_name
         self.r = r
         self.lora_alpha = lora_alpha
@@ -537,6 +583,7 @@ class MergedLinear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(self,
                  adapter_name: str,
+                 module_key: str,
                  base_layer: nn.Linear,
                  r: int = 0,
                  lora_alpha: int = 1,
@@ -558,6 +605,7 @@ def __init__(self,
         LoRALayer.__init__(
             self,
             adapter_name,
+            module_key,
             r=r,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,