
Commit c4c68ad

fix unpatch_lora (#80)
1 parent ab286fa commit c4c68ad

File tree: 3 files changed, +28 -56 lines
(swift/tuners/lora.py, swift/tuners/side.py, tests/tuners/test_swift_base.py)


swift/tuners/lora.py

Lines changed: 22 additions & 55 deletions
@@ -165,7 +165,7 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
         """Prepare a model with `LoRAConfig`"""
         LoRA._dynamic_patch_lora(
             model,
-            replace_modules=config.target_modules,
+            target_modules=config.target_modules,
             r=config.r,
             adapter_name=adapter_name,
             lora_alpha=config.lora_alpha,
@@ -195,32 +195,32 @@ def activate_adapter(module: torch.nn.Module, adapter_name: str,

     @staticmethod
     def _dynamic_patch_lora(model: torch.nn.Module,
-                            replace_modules: Union[str, List[str]],
+                            target_modules: Union[str, List[str]],
                             use_merged_linear: bool, adapter_name: str,
                             **kwargs):
         """Dynamic patch lora to model

         Args:
             model(`torch.nn.Module`): The torch.nn.Module containing the target module to be patched.
-            replace_modules(`Union[str, List[str]]`): The module names to be replaced,
+            target_modules(`Union[str, List[str]]`): The module names to be replaced,
                 the replacing strategy is `end with`.
             use_merged_linear(bool): Whether to replace with merged linear layer.
             adapter_name(str): The adapter name.
             **kwargs: The arguments passed from `tune` which are needed by lora.
         """
         modules = {}
         module_keys = [key for key, _ in model.named_modules()]
-        assert isinstance(replace_modules, (str, list))
+        assert isinstance(target_modules, (str, list))
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
             get_quantization_config(model, method='gptq'))

         for module_key in module_keys:
-            if isinstance(replace_modules, str):
-                target_module_found = re.fullmatch(replace_modules, module_key)
+            if isinstance(target_modules, str):
+                target_module_found = re.fullmatch(target_modules, module_key)
             else:
                 target_module_found = any(
                     module_key.endswith(target_key)
-                    for target_key in replace_modules)
+                    for target_key in target_modules)
             if target_module_found:  # noqa
                 sub_module = model.get_submodule(module_key)

@@ -333,71 +333,38 @@ def _forward(self, *args, **kwargs):
         logger.debug(f'Lora modules(module_key -> adapter_name): {modules}')

     @staticmethod
-    def unpatch_lora(model, config: LoRAConfig):
+    def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
         """Unpatch lora modules and merge the weights to original modules.

         LoRA constructs an additional layer with low-rank decomposition matrices of the weights in the network.
         'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021)
         See https://arxiv.org/abs/2106.09685

         Args:
-            model: The model called with `tune` function.
-            config: The `LoRAConfig` to use.
+            model(`torch.nn.Module`): The model called with `tune` function.
+            config(`LoRAConfig`): The `LoRAConfig` to use.
+            adapter_name(`str`): The adapter name
         """
         module_keys = [key for key, _ in model.named_modules()]
-        assert isinstance(config.replace_modules, (str, list))
-        replace_modules = config.replace_modules
+        assert isinstance(config.target_modules, (str, list))
+        target_modules = config.target_modules

         for module_key in module_keys:
-            if isinstance(replace_modules, str):
-                target_module_found = re.fullmatch(replace_modules, module_key)
+            if isinstance(target_modules, str):
+                target_module_found = re.fullmatch(target_modules, module_key)
             else:
                 target_module_found = any(
                     module_key.endswith(target_key)
-                    for target_key in replace_modules)
+                    for target_key in target_modules)
             if target_module_found:  # noqa
-                parts = module_key.split('.')
-                module = model.get_submodule('.'.join(parts[:-1]))
                 sub_module = model.get_submodule(module_key)
-                _key = parts[-1]
+                lora_module = getattr(sub_module, f'loramodule_{adapter_name}')

-                origin_module = None
-                if isinstance(sub_module, Linear):
-                    origin_module = torch.nn.Linear(
-                        sub_module.in_features,
-                        sub_module.out_features,
-                        bias=hasattr(sub_module, 'bias')
-                        and sub_module.bias is not None,
-                    )
-                elif isinstance(sub_module, Embedding):
-                    origin_module = torch.nn.Embedding(
-                        num_embeddings=sub_module.num_embeddings,
-                        embedding_dim=sub_module.embedding_dim,
-                        padding_idx=sub_module.padding_idx,
-                        max_norm=sub_module.max_norm,
-                        norm_type=sub_module.norm_type,
-                        scale_grad_by_freq=sub_module.scale_grad_by_freq,
-                        sparse=sub_module.sparse,
-                    )
-                elif isinstance(sub_module, Conv2d):
-                    origin_module = torch.nn.Conv2d(
-                        sub_module.in_channels,
-                        sub_module.out_channels,
-                        kernel_size=sub_module.kernel_size,
-                        stride=sub_module.stride,
-                        padding=sub_module.padding,
-                        dilation=sub_module.dilation,
-                        groups=sub_module.groups)
-
-                if origin_module is not None:
-                    sub_module.merge_weights = True
-                    sub_module.eval()
-                    origin_module.weight = sub_module.weight
-                    if getattr(sub_module, 'bias', None) is not None:
-                        origin_module.bias = sub_module.bias
-                    origin_module.to(sub_module.weight.device).to(
-                        sub_module.weight.dtype)
-                    setattr(module, _key, origin_module)
+                if lora_module is not None:
+                    if hasattr(lora_module, 'merge_weights'):
+                        lora_module.merge_weights = True
+                    lora_module.eval()
+                    delattr(sub_module, f'loramodule_{adapter_name}')


 class LoRALayer(ActivationMixin):
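
For orientation, a minimal call-site sketch of the signature change (the `LoRAConfig` arguments and the `model` variable are illustrative assumptions, not taken from this commit): `unpatch_lora` now receives the adapter name so it can locate the `loramodule_{adapter_name}` submodule, merge its weights, and delete it, instead of rebuilding the wrapped Linear/Embedding/Conv2d layer.

# Hedged sketch, assuming `model` is a torch.nn.Module that was previously
# patched with this LoRAConfig (e.g. via Swift.prepare_model).
from swift import LoRAConfig
from swift.tuners import LoRA

lora_config = LoRAConfig(
    r=8, lora_alpha=32, target_modules=['query', 'key', 'value'])

# Old signature (before this commit):
#     LoRA.unpatch_lora(model, lora_config)
# New signature: the adapter to merge and remove must be named explicitly.
LoRA.unpatch_lora(model, lora_config, 'default')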

swift/tuners/side.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,6 @@
 from typing import List, Union

 import torch
-import torchvision
 from torch import nn

 from swift.utils.logger import get_logger
@@ -174,6 +173,7 @@ def __init__(self, dim, side_module_name='fcn4'):
         elif side_module_name == 'mlp':
             self.side_net = Mlp(dim)
         elif side_module_name == 'alexnet':
+            import torchvision
             mm = torchvision.models.alexnet(pretrained=True)
             self.side_net = nn.Sequential(
                 OrderedDict([('features', mm.features),
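
The side.py hunks move the torchvision import from module level into the only branch that uses it. A minimal sketch of that deferred-import pattern, assuming torchvision stays an optional dependency (the helper name below is hypothetical):

def build_alexnet_backbone():
    # Hypothetical helper illustrating the deferred import used above:
    # torchvision is only imported when the 'alexnet' side module is requested,
    # so the other side modules keep working without torchvision installed.
    import torchvision
    return torchvision.models.alexnet(pretrained=True).features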

tests/tuners/test_swift_base.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,7 @@

 from swift import (AdapterConfig, LoRAConfig, PromptConfig, ResTuningConfig,
                    SideConfig, Swift, SwiftModel, push_to_hub)
+from swift.tuners import LoRA


 class TestSwift(unittest.TestCase):
@@ -182,6 +183,10 @@ def reset_parameters(self):
                     torch.isclose(state_dict[key],
                                   state_dict2[key]).flatten().detach().cpu()))

+        LoRA.unpatch_lora(model2, lora_config, 'default')
+        output3 = model2(**input)
+        self.assertTrue(torch.allclose(output1.logits, output3.logits))
+
     def test_swift_multiple_adapters(self):
         model = SbertForSequenceClassification(SbertConfig())
         model2 = copy.deepcopy(model)
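
As an optional follow-up check (a hedged sketch, not part of this commit), the test could also assert that the adapter is really gone after unpatching, relying on the `loramodule_{adapter_name}` naming visible in the lora.py hunk above:

# Hedged sketch: after LoRA.unpatch_lora(model2, lora_config, 'default'),
# no parameter of model2 should remain under a 'loramodule_default' submodule.
self.assertFalse(
    any('loramodule_default' in name for name, _ in model2.named_parameters()))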
