
Commit c0f4585

updates
1 parent 2e70a93 commit c0f4585

File tree

2 files changed: +48 -33 lines changed


src/diffusers/loaders/lora_pipeline.py

Lines changed: 45 additions & 33 deletions
@@ -1326,14 +1326,17 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=None,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+        if not only_text_encoder:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=None,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1845,14 +1848,17 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=network_alphas,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+        if not only_text_encoder:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2118,7 +2124,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     text_encoder_name = TEXT_ENCODER_NAME
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
     def load_lora_into_transformer(
         cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
     ):
@@ -2149,14 +2155,17 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=network_alphas,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+        if not only_text_encoder:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2537,14 +2546,17 @@ def load_lora_into_transformer(
         )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
-        transformer.load_lora_adapter(
-            state_dict,
-            network_alphas=None,
-            adapter_name=adapter_name,
-            _pipeline=_pipeline,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+        keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
+        if not only_text_encoder:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=None,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
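The four `load_lora_into_transformer` hunks above all add the same guard: `transformer.load_lora_adapter` now only runs when the state dict contains at least one key that does not target the text encoder. A minimal standalone sketch of that check, assuming `cls.text_encoder_name` is the usual `"text_encoder"` constant (the keys below are illustrative placeholders, not from a real checkpoint):

```python
# Standalone sketch of the only_text_encoder guard; the keys are
# illustrative placeholders, not taken from a real LoRA checkpoint.
text_encoder_name = "text_encoder"  # assumed value of cls.text_encoder_name

# A LoRA checkpoint that targets only the text encoder:
state_dict = {
    "text_encoder.layers.0.q_proj.lora_A.weight": 0.0,
    "text_encoder.layers.0.q_proj.lora_B.weight": 0.0,
}

keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(text_encoder_name) for key in keys)
print(only_text_encoder)  # True -> the transformer load is skipped entirely

# Once any non-text-encoder key is present, the transformer branch runs again:
state_dict["transformer.blocks.0.attn.to_q.lora_A.weight"] = 0.0
only_text_encoder = all(key.startswith(text_encoder_name) for key in state_dict)
print(only_text_encoder)  # False -> transformer.load_lora_adapter is called
```

One edge case worth noting: `all()` over an empty iterable is `True`, so an empty state dict would also skip the transformer load.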

src/diffusers/loaders/peft.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ def _optionally_disable_offloading(cls, _pipeline):
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
+        """
+        TODO
+        """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
         cache_dir = kwargs.pop("cache_dir", None)
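The new docstring is still a TODO stub, so for orientation here is a hedged usage sketch of `load_lora_adapter`, based only on the signature above and the call sites in `lora_pipeline.py`; the checkpoint id and adapter name are hypothetical:

```python
# Hedged usage sketch of load_lora_adapter on a diffusers model; the
# checkpoint id is hypothetical, and the kwargs mirror the call sites
# in lora_pipeline.py shown earlier in this commit.
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

# Only state-dict keys matching the given prefix ("transformer" by default)
# are injected as a PEFT LoRA adapter on this model.
transformer.load_lora_adapter(
    "some-user/some-flux-lora",  # hypothetical repo id; a local path or a plain dict also fits the parameter name
    prefix="transformer",
    adapter_name="my_lora",
    low_cpu_mem_usage=True,
)
```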
