Skip to content

Commit b6c0a58

Browse files
authored
fix peft is_multimodal (#2462)
1 parent 10a1143 commit b6c0a58

File tree

1 file changed

+16
-38
lines changed

1 file changed

+16
-38
lines changed

swift/tuners/peft.py

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -77,40 +77,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional
7777
return self
7878

7979

80-
def _get_target(*args, **kwargs):
81-
target = None
82-
if 'target' in kwargs:
83-
target = kwargs['target']
84-
else:
85-
for arg in args:
86-
if isinstance(arg, torch.nn.Module):
87-
target = arg
88-
break
89-
return target
90-
91-
92-
def _create_and_replace_hook(self, *args, **kwargs):
93-
target = _get_target(*args, **kwargs)
94-
if target and target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
95-
return
96-
97-
return self._create_and_replace_origin(*args, **kwargs)
98-
99-
100-
def _create_and_replace_hook2(self, *args, **kwargs):
101-
target = _get_target(*args, **kwargs)
102-
80+
def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
10381
all_supported_names = ('linear', )
10482
all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D)
83+
target_modules = getattr(peft_config, 'target_modules', None)
84+
if target is None:
85+
return
10586

106-
is_multimodal = getattr(self.model, 'is_multimodal', False)
107-
108-
if is_multimodal and target is not None and (not any(
87+
if isinstance(target_modules, str) and not any(
10988
[name in target.__class__.__name__.lower()
110-
for name in all_supported_names]) and not any([isinstance(target, type) for type in all_supported_types])):
89+
for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
11190
return
11291

113-
return _create_and_replace_hook(self, *args, **kwargs)
92+
if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
93+
return
94+
95+
return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)
11496

11597

11698
def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
@@ -296,28 +278,24 @@ def keep_device_forward(self, *args, **kwargs):
296278

297279
def hot_patch_peft_module():
298280
from peft.tuners.lora import LoraLayer
281+
if hasattr('LoraModel', '_create_and_replace_origin'):
282+
return
299283

300284
# Fix Lora does not support NonDynamicallyQuantizableLinear
301285
LoraModel._create_and_replace_origin = LoraModel._create_and_replace
302286
LoraModel._create_and_replace = _create_and_replace_hook
303287
VeraModel._create_and_replace_origin = VeraModel._create_and_replace
304-
VeraModel._create_and_replace = _create_and_replace_hook2
288+
VeraModel._create_and_replace = _create_and_replace_hook
305289
BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
306-
BOFTModel._create_and_replace = _create_and_replace_hook2
290+
BOFTModel._create_and_replace = _create_and_replace_hook
307291
IA3Model._create_and_replace_origin = IA3Model._create_and_replace
308-
IA3Model._create_and_replace = _create_and_replace_hook2
292+
IA3Model._create_and_replace = _create_and_replace_hook
309293
if FourierFTModel is not None:
310294
FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
311-
FourierFTModel._create_and_replace = _create_and_replace_hook2
295+
FourierFTModel._create_and_replace = _create_and_replace_hook
312296

313297
# Support type conversion
314298
def init(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name):
315-
if isinstance(config, dict):
316-
for _config in config.values(): # There is a target_modules as a string.
317-
if isinstance(getattr(_config, 'target_modules', None), str):
318-
# Make sure the regex can find all linear in the module.
319-
LoraModel._create_and_replace = _create_and_replace_hook2
320-
break
321299

322300
self.__init_origin__(model, config, adapter_name)
323301
if isinstance(self.active_adapter, list):

0 commit comments

Comments
 (0)