@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
339339 # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
340340 # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
341341 # their prefixes.
342- keys = list (state_dict .keys ())
343342 prefix = text_encoder_name if prefix is None else prefix
344343
345- # Safe prefix to check with.
346- if any (text_encoder_name in key for key in keys ):
347- # Load the layers corresponding to text encoder and make necessary adjustments.
348- text_encoder_keys = [k for k in keys if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
349- text_encoder_lora_state_dict = {
350- k .replace (f"{ prefix } ." , "" ): v for k , v in state_dict .items () if k in text_encoder_keys
351- }
344+ # Load the layers corresponding to text encoder and make necessary adjustments.
345+ if prefix is not None :
346+ state_dict = {k [len (f"{ prefix } ." ) :]: v for k , v in state_dict .items () if k .startswith (f"{ prefix } ." )}
347+
348+ if len (state_dict ) > 0 :
349+ logger .info (f"Loading { prefix } ." )
350+ rank = {}
351+ state_dict = convert_state_dict_to_diffusers (state_dict )
352+
353+ # convert state dict
354+ state_dict = convert_state_dict_to_peft (state_dict )
355+
356+ for name , _ in text_encoder_attn_modules (text_encoder ):
357+ for module in ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
358+ rank_key = f"{ name } .{ module } .lora_B.weight"
359+ if rank_key not in state_dict :
360+ continue
361+ rank [rank_key ] = state_dict [rank_key ].shape [1 ]
362+
363+ for name , _ in text_encoder_mlp_modules (text_encoder ):
364+ for module in ("fc1" , "fc2" ):
365+ rank_key = f"{ name } .{ module } .lora_B.weight"
366+ if rank_key not in state_dict :
367+ continue
368+ rank [rank_key ] = state_dict [rank_key ].shape [1 ]
369+
370+ if network_alphas is not None :
371+ alpha_keys = [k for k in network_alphas .keys () if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
372+ network_alphas = {k .replace (f"{ prefix } ." , "" ): v for k , v in network_alphas .items () if k in alpha_keys }
373+
374+ lora_config_kwargs = get_peft_kwargs (rank , network_alphas , state_dict , is_unet = False )
375+
376+ if "use_dora" in lora_config_kwargs :
377+ if lora_config_kwargs ["use_dora" ]:
378+ if is_peft_version ("<" , "0.9.0" ):
379+ raise ValueError (
380+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
381+ )
382+ else :
383+ if is_peft_version ("<" , "0.9.0" ):
384+ lora_config_kwargs .pop ("use_dora" )
385+
386+ if "lora_bias" in lora_config_kwargs :
387+ if lora_config_kwargs ["lora_bias" ]:
388+ if is_peft_version ("<=" , "0.13.2" ):
389+ raise ValueError (
390+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
391+ )
392+ else :
393+ if is_peft_version ("<=" , "0.13.2" ):
394+ lora_config_kwargs .pop ("lora_bias" )
352395
353- if len (text_encoder_lora_state_dict ) > 0 :
354- logger .info (f"Loading { prefix } ." )
355- rank = {}
356- text_encoder_lora_state_dict = convert_state_dict_to_diffusers (text_encoder_lora_state_dict )
357-
358- # convert state dict
359- text_encoder_lora_state_dict = convert_state_dict_to_peft (text_encoder_lora_state_dict )
360-
361- for name , _ in text_encoder_attn_modules (text_encoder ):
362- for module in ("out_proj" , "q_proj" , "k_proj" , "v_proj" ):
363- rank_key = f"{ name } .{ module } .lora_B.weight"
364- if rank_key not in text_encoder_lora_state_dict :
365- continue
366- rank [rank_key ] = text_encoder_lora_state_dict [rank_key ].shape [1 ]
367-
368- for name , _ in text_encoder_mlp_modules (text_encoder ):
369- for module in ("fc1" , "fc2" ):
370- rank_key = f"{ name } .{ module } .lora_B.weight"
371- if rank_key not in text_encoder_lora_state_dict :
372- continue
373- rank [rank_key ] = text_encoder_lora_state_dict [rank_key ].shape [1 ]
374-
375- if network_alphas is not None :
376- alpha_keys = [k for k in network_alphas .keys () if k .startswith (prefix ) and k .split ("." )[0 ] == prefix ]
377- network_alphas = {k .replace (f"{ prefix } ." , "" ): v for k , v in network_alphas .items () if k in alpha_keys }
378-
379- lora_config_kwargs = get_peft_kwargs (rank , network_alphas , text_encoder_lora_state_dict , is_unet = False )
380-
381- if "use_dora" in lora_config_kwargs :
382- if lora_config_kwargs ["use_dora" ]:
383- if is_peft_version ("<" , "0.9.0" ):
384- raise ValueError (
385- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
386- )
387- else :
388- if is_peft_version ("<" , "0.9.0" ):
389- lora_config_kwargs .pop ("use_dora" )
390-
391- if "lora_bias" in lora_config_kwargs :
392- if lora_config_kwargs ["lora_bias" ]:
393- if is_peft_version ("<=" , "0.13.2" ):
394- raise ValueError (
395- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
396- )
397- else :
398- if is_peft_version ("<=" , "0.13.2" ):
399- lora_config_kwargs .pop ("lora_bias" )
396+ lora_config = LoraConfig (** lora_config_kwargs )
400397
401- lora_config = LoraConfig (** lora_config_kwargs )
398+ # adapter_name
399+ if adapter_name is None :
400+ adapter_name = get_adapter_name (text_encoder )
402401
403- # adapter_name
404- if adapter_name is None :
405- adapter_name = get_adapter_name (text_encoder )
402+ is_model_cpu_offload , is_sequential_cpu_offload = _func_optionally_disable_offloading (_pipeline )
406403
407- is_model_cpu_offload , is_sequential_cpu_offload = _func_optionally_disable_offloading (_pipeline )
404+ # inject LoRA layers and load the state dict
405+ # in transformers we automatically check whether the adapter name is already in use or not
406+ text_encoder .load_adapter (
407+ adapter_name = adapter_name ,
408+ adapter_state_dict = state_dict ,
409+ peft_config = lora_config ,
410+ ** peft_kwargs ,
411+ )
408412
409- # inject LoRA layers and load the state dict
410- # in transformers we automatically check whether the adapter name is already in use or not
411- text_encoder .load_adapter (
412- adapter_name = adapter_name ,
413- adapter_state_dict = text_encoder_lora_state_dict ,
414- peft_config = lora_config ,
415- ** peft_kwargs ,
416- )
413+ # scale LoRA layers with `lora_scale`
414+ scale_lora_layers (text_encoder , weight = lora_scale )
417415
418- # scale LoRA layers with `lora_scale`
419- scale_lora_layers (text_encoder , weight = lora_scale )
416+ text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
420417
421- text_encoder .to (device = text_encoder .device , dtype = text_encoder .dtype )
418+ # Offload back.
419+ if is_model_cpu_offload :
420+ _pipeline .enable_model_cpu_offload ()
421+ elif is_sequential_cpu_offload :
422+ _pipeline .enable_sequential_cpu_offload ()
423+ # Unsafe code />
422424
423- # Offload back.
424- if is_model_cpu_offload :
425- _pipeline .enable_model_cpu_offload ()
426- elif is_sequential_cpu_offload :
427- _pipeline .enable_sequential_cpu_offload ()
428- # Unsafe code />
425+ if prefix is not None and not state_dict :
426+ logger .info (
427+ f"No LoRA keys associated to { text_encoder .__class__ .__name__ } found with the { prefix = } . This is safe to ignore if LoRA state dict didn't originally have any { text_encoder .__class__ .__name__ } related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
428+ )
429429
430430
431431def _func_optionally_disable_offloading (_pipeline ):
0 commit comments