@@ -165,7 +165,7 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
         """Prepare a model with `LoRAConfig`"""
         LoRA._dynamic_patch_lora(
             model,
-            replace_modules=config.target_modules,
+            target_modules=config.target_modules,
             r=config.r,
             adapter_name=adapter_name,
             lora_alpha=config.lora_alpha,
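The rename is purely cosmetic at this call site: `prepare_model` still forwards the config's `target_modules` value, only now under a keyword that matches the config field name. A hypothetical call-site sketch, assuming `prepare_model` is invoked through the `LoRA` class as the surrounding code suggests (the config values are illustrative, not taken from this commit):

# Hypothetical usage sketch: wrap every module whose name ends with 'q_proj'
# or 'v_proj' with a LoRA adapter named 'default'; r and lora_alpha are the
# usual LoRA rank/scaling hyperparameters.
config = LoRAConfig(target_modules=['q_proj', 'v_proj'], r=8, lora_alpha=32)
LoRA.prepare_model(model, config, adapter_name='default')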
@@ -195,32 +195,32 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str,

     @staticmethod
     def _dynamic_patch_lora(model: torch.nn.Module,
-                            replace_modules: Union[str, List[str]],
+                            target_modules: Union[str, List[str]],
                             use_merged_linear: bool, adapter_name: str,
                             **kwargs):
         """Dynamic patch lora to model

         Args:
             model(`torch.nn.Module`): The torch.nn.Module containing the target module to be patched.
-            replace_modules(`Union[str, List[str]]`): The module names to be replaced,
+            target_modules(`Union[str, List[str]]`): The module names to be replaced,
                 the replacing strategy is `end with`.
             use_merged_linear(bool): Whether to replace with merged linear layer.
             adapter_name(str): The adapter name.
             **kwargs: The arguments passed from `tune` which are needed by lora.
         """
         modules = {}
         module_keys = [key for key, _ in model.named_modules()]
-        assert isinstance(replace_modules, (str, list))
+        assert isinstance(target_modules, (str, list))
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
             get_quantization_config(model, method='gptq'))

         for module_key in module_keys:
-            if isinstance(replace_modules, str):
-                target_module_found = re.fullmatch(replace_modules, module_key)
+            if isinstance(target_modules, str):
+                target_module_found = re.fullmatch(target_modules, module_key)
             else:
                 target_module_found = any(
                     module_key.endswith(target_key)
-                    for target_key in replace_modules)
+                    for target_key in target_modules)
             if target_module_found:  # noqa
                 sub_module = model.get_submodule(module_key)

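The only subtle behaviour in this hunk is the matching rule spelled out in the docstring: a string `target_modules` is treated as a regular expression that must match the full dotted module path, while a list is matched by suffix ("end with"). A standalone sketch of that check, with made-up module names:

import re

target_modules = ['q_proj', 'v_proj']            # list form: suffix match
module_key = 'model.layers.0.self_attn.q_proj'   # a dotted name from named_modules()

if isinstance(target_modules, str):
    # string form: the whole dotted path must match the regex
    found = re.fullmatch(target_modules, module_key) is not None
else:
    found = any(module_key.endswith(key) for key in target_modules)

print(found)  # True: the key ends with 'q_proj'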
@@ -333,71 +333,38 @@ def _forward(self, *args, **kwargs):
         logger.debug(f'Lora modules(module_key -> adapter_name): {modules}')

     @staticmethod
-    def unpatch_lora(model, config: LoRAConfig):
+    def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
         """Unpatch lora modules and merge the weights to original modules.

         LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network.
         'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021)
         See https://arxiv.org/abs/2106.09685

         Args:
-            model: The model called with `tune` function.
-            config: The `LoRAConfig` to use.
+            model(`torch.nn.Module`): The model called with `tune` function.
+            config(`LoRAConfig`): The `LoRAConfig` to use.
+            adapter_name(`str`): The adapter name
         """
         module_keys = [key for key, _ in model.named_modules()]
-        assert isinstance(config.replace_modules, (str, list))
-        replace_modules = config.replace_modules
+        assert isinstance(config.target_modules, (str, list))
+        target_modules = config.target_modules

         for module_key in module_keys:
-            if isinstance(replace_modules, str):
-                target_module_found = re.fullmatch(replace_modules, module_key)
+            if isinstance(target_modules, str):
+                target_module_found = re.fullmatch(target_modules, module_key)
             else:
                 target_module_found = any(
                     module_key.endswith(target_key)
-                    for target_key in replace_modules)
+                    for target_key in target_modules)
             if target_module_found:  # noqa
-                parts = module_key.split('.')
-                module = model.get_submodule('.'.join(parts[:-1]))
                 sub_module = model.get_submodule(module_key)
-                _key = parts[-1]
+                lora_module = getattr(sub_module, f'loramodule_{adapter_name}')

-                origin_module = None
-                if isinstance(sub_module, Linear):
-                    origin_module = torch.nn.Linear(
-                        sub_module.in_features,
-                        sub_module.out_features,
-                        bias=hasattr(sub_module, 'bias')
-                        and sub_module.bias is not None,
-                    )
-                elif isinstance(sub_module, Embedding):
-                    origin_module = torch.nn.Embedding(
-                        num_embeddings=sub_module.num_embeddings,
-                        embedding_dim=sub_module.embedding_dim,
-                        padding_idx=sub_module.padding_idx,
-                        max_norm=sub_module.max_norm,
-                        norm_type=sub_module.norm_type,
-                        scale_grad_by_freq=sub_module.scale_grad_by_freq,
-                        sparse=sub_module.sparse,
-                    )
-                elif isinstance(sub_module, Conv2d):
-                    origin_module = torch.nn.Conv2d(
-                        sub_module.in_channels,
-                        sub_module.out_channels,
-                        kernel_size=sub_module.kernel_size,
-                        stride=sub_module.stride,
-                        padding=sub_module.padding,
-                        dilation=sub_module.dilation,
-                        groups=sub_module.groups)
-
-                if origin_module is not None:
-                    sub_module.merge_weights = True
-                    sub_module.eval()
-                    origin_module.weight = sub_module.weight
-                    if getattr(sub_module, 'bias', None) is not None:
-                        origin_module.bias = sub_module.bias
-                    origin_module.to(sub_module.weight.device).to(
-                        sub_module.weight.dtype)
-                    setattr(module, _key, origin_module)
+                if lora_module is not None:
+                    if hasattr(lora_module, 'merge_weights'):
+                        lora_module.merge_weights = True
+                        lora_module.eval()
+                    delattr(sub_module, f'loramodule_{adapter_name}')


 class LoRALayer(ActivationMixin):
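Because adapters are now attached to the patched sub-modules as `loramodule_<adapter_name>` attributes rather than by swapping out the original layers, unpatching has to be told which adapter to merge and remove, hence the added `adapter_name` parameter. A minimal call-site sketch under that reading of the diff:

# Merge the 'default' adapter's low-rank weights into the base modules selected
# by config.target_modules, then drop its loramodule_default attributes.
LoRA.unpatch_lora(model, config, adapter_name='default')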