
Commit 2e70a93 ("updates"), parent 984b8c9

src/diffusers/loaders/lora_pipeline.py (1 file changed: 23 additions, 84 deletions)
@@ -1830,7 +1830,7 @@ def load_lora_into_transformer(
                 The value of the network alpha used for stable learning and preventing underflow. This value has the
                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-            transformer (`SD3Transformer2DModel`):
+            transformer (`FluxTransformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -2118,7 +2118,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     text_encoder_name = TEXT_ENCODER_NAME

     @classmethod
-    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -2131,93 +2134,29 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
                 The value of the network alpha used for stable learning and preventing underflow. This value has the
                 same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                 link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-            unet (`UNet2DConditionModel`):
-                The UNet model to load the LoRA layers into.
+            transformer (`UVit2DModel`):
+                The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        keys = list(state_dict.keys())
-
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
-
-        if network_alphas is not None:
-            alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
-            network_alphas = {
-                k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-            }
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )

-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
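
After this change, `AmusedLoraLoaderMixin.load_lora_into_transformer` is a thin wrapper over `transformer.load_lora_adapter`. A minimal sketch of exercising the new signature directly; the checkpoint path and adapter name below are hypothetical placeholders, with `amused/amused-512` assumed as the base pipeline:

import torch
from safetensors.torch import load_file
from diffusers import AmusedPipeline
from diffusers.loaders.lora_pipeline import AmusedLoraLoaderMixin

pipe = AmusedPipeline.from_pretrained("amused/amused-512", torch_dtype=torch.float16)

# LoRA weights for the UVit2DModel transformer; keys are expected to carry
# the "transformer." prefix that diffusers training scripts emit.
state_dict = load_file("amused_lora.safetensors")  # hypothetical local file

# The classmethod now just forwards to transformer.load_lora_adapter;
# low_cpu_mem_usage=True raises a ValueError unless peft >= 0.13.1 is installed.
AmusedLoraLoaderMixin.load_lora_into_transformer(
    state_dict,
    network_alphas=None,
    transformer=pipe.transformer,
    adapter_name="my_lora",  # hypothetical adapter name
    low_cpu_mem_usage=True,
)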

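The guard at the top of the new body ties `low_cpu_mem_usage` to peft 0.13.1, the earliest release the diff accepts for low-memory adapter loading. A self-contained sketch of an equivalent check; `peft_version_satisfies` is an illustrative stand-in, not the actual `is_peft_version` helper used by diffusers:

import operator
from importlib.metadata import version

from packaging.version import parse

# Map comparison strings to operators, mirroring the ("<", ">=", ...)
# style of the is_peft_version calls in the diff above.
_OPS = {"<": operator.lt, "<=": operator.le, ">": operator.gt, ">=": operator.ge, "==": operator.eq}


def peft_version_satisfies(op: str, required: str) -> bool:
    """Return True if the installed peft release satisfies `op required`."""
    return _OPS[op](parse(version("peft")), parse(required))


if not peft_version_satisfies(">=", "0.13.1"):
    raise ValueError(
        "`low_cpu_mem_usage=True` is not compatible with this `peft` version. "
        "Please update it with `pip install -U peft`."
    )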