@@ -1830,7 +1830,7 @@ def load_lora_into_transformer(
18301830 The value of the network alpha used for stable learning and preventing underflow. This value has the
18311831 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
18321832 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
1833- transformer (`SD3Transformer2DModel `):
1833+ transformer (`FluxTransformer2DModel `):
18341834 The Transformer model to load the LoRA layers into.
18351835 adapter_name (`str`, *optional*):
18361836 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -2118,7 +2118,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
21182118 text_encoder_name = TEXT_ENCODER_NAME
21192119
21202120 @classmethod
2121- def load_lora_into_transformer (cls , state_dict , network_alphas , transformer , adapter_name = None , _pipeline = None ):
2121+ # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
2122+ def load_lora_into_transformer (
2123+ cls , state_dict , network_alphas , transformer , adapter_name = None , _pipeline = None , low_cpu_mem_usage = False
2124+ ):
21222125 """
21232126 This will load the LoRA layers specified in `state_dict` into `transformer`.
21242127
@@ -2131,93 +2134,29 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
21312134 The value of the network alpha used for stable learning and preventing underflow. This value has the
21322135 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
21332136 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
2134- unet (`UNet2DConditionModel `):
2135- The UNet model to load the LoRA layers into.
2137+ transformer (`UVit2DModel `):
2138+ The Transformer model to load the LoRA layers into.
21362139 adapter_name (`str`, *optional*):
21372140 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
21382141 `default_{i}` where i is the total number of adapters being loaded.
2142+ low_cpu_mem_usage (`bool`, *optional*):
2143+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
2144+ weights.
21392145 """
2140- if not USE_PEFT_BACKEND :
2141- raise ValueError ("PEFT backend is required for this method." )
2142-
2143- from peft import LoraConfig , inject_adapter_in_model , set_peft_model_state_dict
2144-
2145- keys = list (state_dict .keys ())
2146-
2147- transformer_keys = [k for k in keys if k .startswith (cls .transformer_name )]
2148- state_dict = {
2149- k .replace (f"{ cls .transformer_name } ." , "" ): v for k , v in state_dict .items () if k in transformer_keys
2150- }
2151-
2152- if network_alphas is not None :
2153- alpha_keys = [k for k in network_alphas .keys () if k .startswith (cls .transformer_name )]
2154- network_alphas = {
2155- k .replace (f"{ cls .transformer_name } ." , "" ): v for k , v in network_alphas .items () if k in alpha_keys
2156- }
2157-
2158- if len (state_dict .keys ()) > 0 :
2159- if adapter_name in getattr (transformer , "peft_config" , {}):
2160- raise ValueError (
2161- f"Adapter name { adapter_name } already in use in the transformer - please select a new adapter name."
2162- )
2163-
2164- rank = {}
2165- for key , val in state_dict .items ():
2166- if "lora_B" in key :
2167- rank [key ] = val .shape [1 ]
2146+ if low_cpu_mem_usage and not is_peft_version (">=" , "0.13.1" ):
2147+ raise ValueError (
2148+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
2149+ )
21682150
2169- lora_config_kwargs = get_peft_kwargs (rank , network_alphas , state_dict )
2170- if "use_dora" in lora_config_kwargs :
2171- if lora_config_kwargs ["use_dora" ] and is_peft_version ("<" , "0.9.0" ):
2172- raise ValueError (
2173- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
2174- )
2175- else :
2176- lora_config_kwargs .pop ("use_dora" )
2177- lora_config = LoraConfig (** lora_config_kwargs )
2178-
2179- # adapter_name
2180- if adapter_name is None :
2181- adapter_name = get_adapter_name (transformer )
2182-
2183- # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
2184- # otherwise loading LoRA weights will lead to an error
2185- is_model_cpu_offload , is_sequential_cpu_offload = cls ._optionally_disable_offloading (_pipeline )
2186-
2187- inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name )
2188- incompatible_keys = set_peft_model_state_dict (transformer , state_dict , adapter_name )
2189-
2190- warn_msg = ""
2191- if incompatible_keys is not None :
2192- # Check only for unexpected keys.
2193- unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
2194- if unexpected_keys :
2195- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k ]
2196- if lora_unexpected_keys :
2197- warn_msg = (
2198- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
2199- f" { ', ' .join (lora_unexpected_keys )} . "
2200- )
2201-
2202- # Filter missing keys specific to the current adapter.
2203- missing_keys = getattr (incompatible_keys , "missing_keys" , None )
2204- if missing_keys :
2205- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k ]
2206- if lora_missing_keys :
2207- warn_msg += (
2208- f"Loading adapter weights from state_dict led to missing keys in the model:"
2209- f" { ', ' .join (lora_missing_keys )} ."
2210- )
2211-
2212- if warn_msg :
2213- logger .warning (warn_msg )
2214-
2215- # Offload back.
2216- if is_model_cpu_offload :
2217- _pipeline .enable_model_cpu_offload ()
2218- elif is_sequential_cpu_offload :
2219- _pipeline .enable_sequential_cpu_offload ()
2220- # Unsafe code />
2151+ # Load the layers corresponding to transformer.
2152+ logger .info (f"Loading { cls .transformer_name } ." )
2153+ transformer .load_lora_adapter (
2154+ state_dict ,
2155+ network_alphas = network_alphas ,
2156+ adapter_name = adapter_name ,
2157+ _pipeline = _pipeline ,
2158+ low_cpu_mem_usage = low_cpu_mem_usage ,
2159+ )
22212160
22222161 @classmethod
22232162 # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
0 commit comments