     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import (
+    _create_lora_config,
+    _lora_loading_context,
+    _maybe_warn_for_unhandled_keys,
+    _maybe_warn_if_no_keys_found,
+)
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales

 }


-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists(lora_config, key).
-        # But this cuts it for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -189,7 +174,7 @@ def load_lora_adapter(
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             metadata: TODO
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft import inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

         cache_dir = kwargs.pop("cache_dir", None)
@@ -214,7 +199,6 @@ def load_lora_adapter(
             )

         user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
-
         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
@@ -273,38 +257,8 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if metadata is not None:
-                lora_config_kwargs = metadata
-            else:
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
-                )
-            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
-
-            try:
-                lora_config = LoraConfig(**lora_config_kwargs)
-            except TypeError as e:
-                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+            # create LoraConfig
+            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, self.lora_layer_modules)

             # adapter_name
             if adapter_name is None:
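The `_create_lora_config` helper called above comes from `..utils.peft_utils` and takes over the config-building logic removed in this hunk. Below is a minimal sketch of its likely shape, assuming it simply folds together the deleted `get_peft_kwargs` call, the PEFT version gating, and the `LoraConfig` instantiation; the exact body lives in `peft_utils` and may differ. The final argument (passed as `self.lora_layer_modules` at the call site) is accepted but left unused here; presumably it feeds the ambiguity check that `_maybe_raise_error_for_ambiguity` used to perform.

```python
# Illustrative sketch only; the actual helper is defined in diffusers/utils/peft_utils.py.
from peft import LoraConfig

from diffusers.utils import get_peft_kwargs, is_peft_version


def _create_lora_config(state_dict, network_alphas, metadata, rank, model_lora_modules=None):
    # Prefer the config recorded in the checkpoint metadata; otherwise derive it from
    # the ranks, the network alphas, and the state dict itself.
    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)

    # Reject or drop options that older PEFT releases do not understand.
    if lora_config_kwargs.get("use_dora") and is_peft_version("<", "0.9.0"):
        raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
    if "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
        lora_config_kwargs.pop("use_dora")

    if lora_config_kwargs.get("lora_bias") and is_peft_version("<=", "0.13.2"):
        raise ValueError("You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs.")
    if "lora_bias" in lora_config_kwargs and is_peft_version("<=", "0.13.2"):
        lora_config_kwargs.pop("lora_bias")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e
```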
@@ -315,132 +269,98 @@ def load_lora_adapter(
             # Now we remove any existing hooks to `_pipeline`.

             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
-
-            peft_kwargs = {}
-            if is_peft_version(">=", "0.13.1"):
-                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-            if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
-                if is_peft_version(">", "0.14.0"):
-                    from peft.utils.hotswap import (
-                        check_hotswap_configs_compatible,
-                        hotswap_adapter_from_state_dict,
-                        prepare_model_for_compiled_hotswap,
-                    )
-                else:
-                    msg = (
-                        "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
-                        "from source."
-                    )
-                    raise ImportError(msg)
-
-            if hotswap:
-
-                def map_state_dict_for_hotswap(sd):
-                    # For hotswapping, we need the adapter name to be present in the state dict keys
-                    new_sd = {}
-                    for k, v in sd.items():
-                        if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
-                            k = k[: -len(".weight")] + f".{adapter_name}.weight"
-                        elif k.endswith("lora_B.bias"):  # lora_bias=True option
-                            k = k[: -len(".bias")] + f".{adapter_name}.bias"
-                        new_sd[k] = v
-                    return new_sd
-
-            # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
-            # we should also delete the `peft_config` associated to the `adapter_name`.
-            try:
-                if hotswap:
-                    state_dict = map_state_dict_for_hotswap(state_dict)
-                    check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
-                    try:
-                        hotswap_adapter_from_state_dict(
-                            model=self,
-                            state_dict=state_dict,
-                            adapter_name=adapter_name,
-                            config=lora_config,
+            # otherwise loading LoRA weights will lead to an error. So, we use a context manager here
+            # that takes care of enabling and disabling offloading in the pipeline automatically.
+            with _lora_loading_context(_pipeline):
+                peft_kwargs = {}
+                if is_peft_version(">=", "0.13.1"):
+                    peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+                if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
+                    if is_peft_version(">", "0.14.0"):
+                        from peft.utils.hotswap import (
+                            check_hotswap_configs_compatible,
+                            hotswap_adapter_from_state_dict,
+                            prepare_model_for_compiled_hotswap,
                         )
-                    except Exception as e:
-                        logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
-                        raise
-                    # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
-                    # it to None
-                    incompatible_keys = None
-                else:
-                    inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-                    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
-
-                if self._prepare_lora_hotswap_kwargs is not None:
-                    # For hotswapping of compiled models or adapters with different ranks.
-                    # If the user called enable_lora_hotswap, we need to ensure it is called:
-                    # - after the first adapter was loaded
-                    # - before the model is compiled and the 2nd adapter is being hotswapped in
-                    # Therefore, it needs to be called here
-                    prepare_model_for_compiled_hotswap(
-                        self, config=lora_config, **self._prepare_lora_hotswap_kwargs
-                    )
-                    # We only want to call prepare_model_for_compiled_hotswap once
-                    self._prepare_lora_hotswap_kwargs = None
-
-                # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
-                if not self._hf_peft_config_loaded:
-                    self._hf_peft_config_loaded = True
-            except Exception as e:
-                # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
-                if hasattr(self, "peft_config"):
-                    for module in self.modules():
-                        if isinstance(module, BaseTunerLayer):
-                            active_adapters = module.active_adapters
-                            for active_adapter in active_adapters:
-                                if adapter_name in active_adapter:
-                                    module.delete_adapter(adapter_name)
-
-                    self.peft_config.pop(adapter_name)
-                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
-                raise
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
+                    else:
+                        msg = (
+                            "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
+                            "from source."
                         )
+                        raise ImportError(msg)

-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
+                if hotswap:

-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
+                    def map_state_dict_for_hotswap(sd):
+                        # For hotswapping, we need the adapter name to be present in the state dict keys
+                        new_sd = {}
+                        for k, v in sd.items():
+                            if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
+                                k = k[: -len(".weight")] + f".{adapter_name}.weight"
+                            elif k.endswith("lora_B.bias"):  # lora_bias=True option
+                                k = k[: -len(".bias")] + f".{adapter_name}.bias"
+                            new_sd[k] = v
+                        return new_sd
+
+                # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
+                # we should also delete the `peft_config` associated to the `adapter_name`.
+                try:
+                    if hotswap:
+                        state_dict = map_state_dict_for_hotswap(state_dict)
+                        check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+                        try:
+                            hotswap_adapter_from_state_dict(
+                                model=self,
+                                state_dict=state_dict,
+                                adapter_name=adapter_name,
+                                config=lora_config,
+                            )
+                        except Exception as e:
+                            logger.error(
+                                f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}"
+                            )
+                            raise
+                        # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
+                        # it to None
+                        incompatible_keys = None
+                    else:
+                        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+
+                    if self._prepare_lora_hotswap_kwargs is not None:
+                        # For hotswapping of compiled models or adapters with different ranks.
+                        # If the user called enable_lora_hotswap, we need to ensure it is called:
+                        # - after the first adapter was loaded
+                        # - before the model is compiled and the 2nd adapter is being hotswapped in
+                        # Therefore, it needs to be called here
+                        prepare_model_for_compiled_hotswap(
+                            self, config=lora_config, **self._prepare_lora_hotswap_kwargs
+                        )
+                        # We only want to call prepare_model_for_compiled_hotswap once
+                        self._prepare_lora_hotswap_kwargs = None
+
+                    # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
+                    if not self._hf_peft_config_loaded:
+                        self._hf_peft_config_loaded = True
+                except Exception as e:
+                    # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
+                    if hasattr(self, "peft_config"):
+                        for module in self.modules():
+                            if isinstance(module, BaseTunerLayer):
+                                active_adapters = module.active_adapters
+                                for active_adapter in active_adapters:
+                                    if adapter_name in active_adapter:
+                                        module.delete_adapter(adapter_name)
+
+                        self.peft_config.pop(adapter_name)
+                    logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
+                    raise
             # Unsafe code />

-        if prefix is not None and not state_dict:
-            logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
-                "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
-                "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
-                "https://github.com/huggingface/diffusers/issues/new"
-            )
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
+
+        _maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name=self.__class__.__name__)

     def save_lora_adapter(
         self,
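The removed offloading bookkeeping above (disable the hooks before loading, then re-enable model or sequential CPU offload afterwards) is what the new `_lora_loading_context` context manager is expected to encapsulate. Below is a minimal sketch, assuming it wraps `_func_optionally_disable_offloading` (already imported from `lora_base` at the top of this file) and that this function returns the same `(is_model_cpu_offload, is_sequential_cpu_offload)` pair as the deleted `self._optionally_disable_offloading` call.

```python
# Illustrative sketch only; the actual context manager is defined in
# diffusers/utils/peft_utils.py and may obtain the offload state differently.
from contextlib import contextmanager

from diffusers.loaders.lora_base import _func_optionally_disable_offloading


@contextmanager
def _lora_loading_context(_pipeline):
    # Temporarily strip CPU-offload hooks so that loading LoRA weights does not error out.
    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
    yield
    # Re-enable offloading once the body has finished, mirroring the deleted "Offload back."
    # block (on failure the exception propagates before this point, as it did before).
    if is_model_cpu_offload:
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()
```

A context manager keeps the disable/re-enable pair in one place instead of repeating the re-enable logic at every exit point of `load_lora_adapter`.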