3434 delete_adapter_layers ,
3535 deprecate ,
3636 get_adapter_name ,
37- get_peft_kwargs ,
3837 is_accelerate_available ,
3938 is_peft_available ,
4039 is_peft_version ,
4645 set_adapter_layers ,
4746 set_weights_and_activate_adapters ,
4847)
48+ from ..utils .peft_utils import _create_lora_config , _lora_loading_context , _maybe_warn_if_no_keys_found
4949from ..utils .state_dict_utils import _load_sft_state_dict_metadata
5050
5151
5252if is_transformers_available ():
5353 from transformers import PreTrainedModel
5454
55- from ..models .lora import text_encoder_attn_modules , text_encoder_mlp_modules
56-
5755if is_peft_available ():
5856 from peft .tuners .tuners_utils import BaseTunerLayer
5957
@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
352350 )
353351 peft_kwargs ["low_cpu_mem_usage" ] = low_cpu_mem_usage
354352
355- from peft import LoraConfig
356-
357353 # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
358354 # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
359355 # their prefixes.
@@ -377,89 +373,41 @@ def _load_lora_into_text_encoder(
377373 # convert state dict
378374 state_dict = convert_state_dict_to_peft (state_dict )
379375
380- for name , _ in text_encoder_attn_modules (text_encoder ):
381- for module in ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
382- rank_key = f"{ name } .{ module } .lora_B.weight"
383- if rank_key not in state_dict :
384- continue
385- rank [rank_key ] = state_dict [rank_key ].shape [1 ]
386-
387- for name , _ in text_encoder_mlp_modules (text_encoder ):
388- for module in ("fc1" , "fc2" ):
389- rank_key = f"{ name } .{ module } .lora_B.weight"
390- if rank_key not in state_dict :
391- continue
392- rank [rank_key ] = state_dict [rank_key ].shape [1 ]
376+ for name , _ in text_encoder .named_modules ():
377+ if name .endswith ((".q_proj" , ".k_proj" , ".v_proj" , ".out_proj" , ".fc1" , ".fc2" )):
378+ rank_key = f"{ name } .lora_B.weight"
379+ if rank_key in state_dict :
380+ rank [rank_key ] = state_dict [rank_key ].shape [1 ]
393381
394382 if network_alphas is not None :
395383 alpha_keys = [k for k in network_alphas .keys () if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
396384 network_alphas = {k .removeprefix (f"{ prefix } ." ): v for k , v in network_alphas .items () if k in alpha_keys }
397385
398- if metadata is not None :
399- lora_config_kwargs = metadata
400- else :
401- lora_config_kwargs = get_peft_kwargs (rank , network_alphas , state_dict , is_unet = False )
402-
403- if "use_dora" in lora_config_kwargs :
404- if lora_config_kwargs ["use_dora" ]:
405- if is_peft_version ("<" , "0.9.0" ):
406- raise ValueError (
407- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
408- )
409- else :
410- if is_peft_version ("<" , "0.9.0" ):
411- lora_config_kwargs .pop ("use_dora" )
412-
413- if "lora_bias" in lora_config_kwargs :
414- if lora_config_kwargs ["lora_bias" ]:
415- if is_peft_version ("<=" , "0.13.2" ):
416- raise ValueError (
417- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
418- )
419- else :
420- if is_peft_version ("<=" , "0.13.2" ):
421- lora_config_kwargs .pop ("lora_bias" )
422-
423- try :
424- lora_config = LoraConfig (** lora_config_kwargs )
425- except TypeError as e :
426- raise TypeError ("`LoraConfig` class could not be instantiated." ) from e
386+ # create `LoraConfig`
387+ lora_config = _create_lora_config (state_dict , network_alphas , metadata , rank , is_unet = False )
427388
428389 # adapter_name
429390 if adapter_name is None :
430391 adapter_name = get_adapter_name (text_encoder )
431392
432- is_model_cpu_offload , is_sequential_cpu_offload = _func_optionally_disable_offloading (_pipeline )
433-
434- # inject LoRA layers and load the state dict
435- # in transformers we automatically check whether the adapter name is already in use or not
436- text_encoder .load_adapter (
437- adapter_name = adapter_name ,
438- adapter_state_dict = state_dict ,
439- peft_config = lora_config ,
440- ** peft_kwargs ,
441- )
442-
443- # scale LoRA layers with `lora_scale`
444- scale_lora_layers (text_encoder , weight = lora_scale )
393+ # <Unsafe code
394+ with _lora_loading_context (_pipeline ):
395+ # inject LoRA layers and load the state dict
396+ # in transformers we automatically check whether the adapter name is already in use or not
397+ text_encoder .load_adapter (
398+ adapter_name = adapter_name ,
399+ adapter_state_dict = state_dict ,
400+ peft_config = lora_config ,
401+ ** peft_kwargs ,
402+ )
445403
446- text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
404+ # scale LoRA layers with `lora_scale`
405+ scale_lora_layers (text_encoder , weight = lora_scale )
447406
448- # Offload back.
449- if is_model_cpu_offload :
450- _pipeline .enable_model_cpu_offload ()
451- elif is_sequential_cpu_offload :
452- _pipeline .enable_sequential_cpu_offload ()
407+ text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
453408 # Unsafe code />
454409
455- if prefix is not None and not state_dict :
456- logger .warning (
457- f"No LoRA keys associated to { text_encoder .__class__ .__name__ } found with the { prefix = } . "
458- "This is safe to ignore if LoRA state dict didn't originally have any "
459- f"{ text_encoder .__class__ .__name__ } related params. You can also try specifying `prefix=None` "
460- "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
461- "https://github.com/huggingface/diffusers/issues/new"
462- )
410+ _maybe_warn_if_no_keys_found (state_dict , prefix , model_class_name = text_encoder .__class__ .__name__ )
463411
464412
465413def _func_optionally_disable_offloading (_pipeline ):
0 commit comments