@@ -77,40 +77,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional
     return self


-def _get_target(*args, **kwargs):
-    target = None
-    if 'target' in kwargs:
-        target = kwargs['target']
-    else:
-        for arg in args:
-            if isinstance(arg, torch.nn.Module):
-                target = arg
-                break
-    return target
-
-
-def _create_and_replace_hook(self, *args, **kwargs):
-    target = _get_target(*args, **kwargs)
-    if target and target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
-        return
-
-    return self._create_and_replace_origin(*args, **kwargs)
-
-
-def _create_and_replace_hook2(self, *args, **kwargs):
-    target = _get_target(*args, **kwargs)
-
+def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
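+    # Shared _create_and_replace hook for LoraModel, VeraModel, BOFTModel,
+    # IA3Model and FourierFTModel. PEFT's BaseTuner passes the target module
+    # positionally, so the kwargs inspection done by _get_target is unnecessary.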
     all_supported_names = ('linear', )
     all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D)
+    target_modules = getattr(peft_config, 'target_modules', None)
+    if target is None:
+        return

-    is_multimodal = getattr(self.model, 'is_multimodal', False)
-
-    if is_multimodal and target is not None and (not any(
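+    # A string target_modules is interpreted by PEFT as a regex and can match
+    # layers the tuner cannot wrap; only proceed for Linear-like names and the
+    # explicitly supported module types above.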
+    if isinstance(target_modules, str) and not any(
         [name in target.__class__.__name__.lower()
-         for name in all_supported_names]) and not any([isinstance(target, type) for type in all_supported_types])):
+         for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
         return

-    return _create_and_replace_hook(self, *args, **kwargs)
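+    # LoRA cannot wrap NonDynamicallyQuantizableLinear (e.g. the out_proj of
+    # nn.MultiheadAttention), so such targets are skipped entirely.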
+    if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
+        return
+
+    return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)


 def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
@@ -296,28 +278,24 @@ def keep_device_forward(self, *args, **kwargs):

 def hot_patch_peft_module():
     from peft.tuners.lora import LoraLayer
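+    # Idempotency guard: if the originals were already saved, a second patch
+    # would overwrite _create_and_replace_origin with the hook itself and
+    # cause infinite recursion.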
+    if hasattr(LoraModel, '_create_and_replace_origin'):
+        return

     # Fix: LoRA does not support NonDynamicallyQuantizableLinear.
     LoraModel._create_and_replace_origin = LoraModel._create_and_replace
     LoraModel._create_and_replace = _create_and_replace_hook
     VeraModel._create_and_replace_origin = VeraModel._create_and_replace
-    VeraModel._create_and_replace = _create_and_replace_hook2
+    VeraModel._create_and_replace = _create_and_replace_hook
     BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
-    BOFTModel._create_and_replace = _create_and_replace_hook2
+    BOFTModel._create_and_replace = _create_and_replace_hook
     IA3Model._create_and_replace_origin = IA3Model._create_and_replace
-    IA3Model._create_and_replace = _create_and_replace_hook2
+    IA3Model._create_and_replace = _create_and_replace_hook
     if FourierFTModel is not None:
         FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
-        FourierFTModel._create_and_replace = _create_and_replace_hook2
+        FourierFTModel._create_and_replace = _create_and_replace_hook

     # Support type conversion
     def init(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name):
-        if isinstance(config, dict):
-            for _config in config.values():  # There is a target_modules as a string.
-                if isinstance(getattr(_config, 'target_modules', None), str):
-                    # Make sure the regex can find all linear in the module.
-                    LoraModel._create_and_replace = _create_and_replace_hook2
-                    break

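+        # Delegate to the original __init__; the string target_modules special
+        # case is now handled inside _create_and_replace_hook, so no hook2
+        # switching is required here.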
         self.__init_origin__(model, config, adapter_name)
         if isinstance(self.active_adapter, list):