Commit 1db7503: updates
1 parent b694ca4 commit 1db7503

File tree: 3 files changed, +106 −116 lines


src/diffusers/loaders/lora_pipeline.py

Lines changed: 81 additions & 111 deletions
@@ -294,20 +294,15 @@ def load_lora_into_unet(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
-        # their prefixes.
-        if any(k.startswith(f"{cls.unet_name}.") for k in state_dict):
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        # Load the layers corresponding to UNet.
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_text_encoder(
@@ -663,18 +658,16 @@ def load_lora_weights(
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
         if len(text_encoder_2_state_dict) > 0:
@@ -839,20 +832,15 @@ def load_lora_into_unet(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
-        # their prefixes.
-        if any(k.startswith(f"{cls.unet_name}.") for k in state_dict):
-            # Load the layers corresponding to UNet.
-            logger.info(f"Loading {cls.unet_name}.")
-            unet.load_lora_adapter(
-                state_dict,
-                prefix=cls.unet_name,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        # Load the layers corresponding to UNet.
+        unet.load_lora_adapter(
+            state_dict,
+            prefix=cls.unet_name,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1294,43 +1282,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
-        if len(transformer_state_dict) > 0:
-            self.load_lora_into_transformer(
-                state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
-        if len(text_encoder_2_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_2_state_dict,
-                network_alphas=None,
-                text_encoder=self.text_encoder_2,
-                prefix="text_encoder_2",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=None,
+            text_encoder=self.text_encoder_2,
+            prefix="text_encoder_2",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_transformer(
@@ -1359,7 +1339,6 @@ def load_lora_into_transformer(
             )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
         transformer.load_lora_adapter(
             state_dict,
             network_alphas=None,
@@ -1855,7 +1834,7 @@ def load_lora_weights(
             raise ValueError("Invalid LoRA checkpoint.")
 
         transformer_lora_state_dict = {
-            k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+            k: state_dict.get(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
         }
         transformer_norm_state_dict = {
             k: state_dict.pop(k)
@@ -1875,15 +1854,14 @@ def load_lora_weights(
                 "To get a comprehensive list of parameter names that were modified, enable debug logging."
             )
 
-        if len(transformer_lora_state_dict) > 0:
-            self.load_lora_into_transformer(
-                transformer_lora_state_dict,
-                network_alphas=network_alphas,
-                transformer=transformer,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_transformer(
+            state_dict,
+            network_alphas=network_alphas,
+            transformer=transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
         if len(transformer_norm_state_dict) > 0:
             transformer._transformer_norm_layers = self._load_norm_into_transformer(
@@ -1892,18 +1870,16 @@
                 discard_original_layers=False,
             )
 
-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
-        if len(text_encoder_state_dict) > 0:
-            self.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=self.text_encoder,
-                prefix="text_encoder",
-                lora_scale=self.lora_scale,
-                adapter_name=adapter_name,
-                _pipeline=self,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        self.load_lora_into_text_encoder(
+            state_dict,
+            network_alphas=network_alphas,
+            text_encoder=self.text_encoder,
+            prefix="text_encoder",
+            lora_scale=self.lora_scale,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def load_lora_into_transformer(
@@ -1936,17 +1912,13 @@ def load_lora_into_transformer(
             )
 
         # Load the layers corresponding to transformer.
-        keys = list(state_dict.keys())
-        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
-        if transformer_present:
-            logger.info(f"Loading {cls.transformer_name}.")
-            transformer.load_lora_adapter(
-                state_dict,
-                network_alphas=network_alphas,
-                adapter_name=adapter_name,
-                _pipeline=_pipeline,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def _load_norm_into_transformer(
@@ -2837,7 +2809,6 @@ def load_lora_into_transformer(
             )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
         transformer.load_lora_adapter(
             state_dict,
             network_alphas=None,
@@ -3145,7 +3116,6 @@ def load_lora_into_transformer(
             )
 
         # Load the layers corresponding to transformer.
-        logger.info(f"Loading {cls.transformer_name}.")
         transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
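
Taken together, these hunks make one change to lora_pipeline.py: the pipeline-level loaders no longer pre-filter the state dict by component prefix (or log "Loading <component>." themselves); they pass the full state_dict through and let the model-level load_lora_adapter select the keys that match its prefix and emit the log. As a rough illustration of that kind of prefix selection (a simplified sketch, not the diffusers implementation; filter_by_prefix and the sample keys below are made-up placeholders):

# Simplified, self-contained sketch of prefix-based key selection, similar in
# spirit to what load_lora_adapter does internally.
def filter_by_prefix(state_dict, prefix):
    # With no prefix, keep everything; otherwise keep only keys under "<prefix>."
    if prefix is None:
        return dict(state_dict)
    marker = f"{prefix}."
    return {k[len(marker):]: v for k, v in state_dict.items() if k.startswith(marker)}

full_state_dict = {
    "unet.mid_block.lora_A.weight": 0.1,
    "text_encoder.layers.0.lora_B.weight": 0.2,
}
print(filter_by_prefix(full_state_dict, "unet"))          # {'mid_block.lora_A.weight': 0.1}
print(filter_by_prefix(full_state_dict, "text_encoder"))  # {'layers.0.lora_B.weight': 0.2}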

src/diffusers/loaders/peft.py

Lines changed: 5 additions & 0 deletions
@@ -257,6 +257,11 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
             state_dict = {}
 
         if len(state_dict) > 0:
+            if prefix is None:
+                component_name = "unet" if "UNet" in self.__class__.__name__ else "transformer"
+            else:
+                component_name = prefix
+            logger.info(f"Loading {component_name}.")
             if adapter_name in getattr(self, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."

tests/lora/utils.py

Lines changed: 20 additions & 5 deletions
@@ -1898,11 +1898,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
-        logger = (
-            logging.get_logger("diffusers.loaders.lora_pipeline")
-            if "text_encoder" in self.pipeline_class._lora_loadable_modules
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(logging.INFO)
 
         with CaptureLogger(logger) as cap_logger:
@@ -1911,3 +1907,22 @@
 
         self.assertTrue(cap_logger.out.startswith("No LoRA keys found in the provided state dict"))
         self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+
+        # test only for text encoder
+        for lora_module in self.pipeline_class._lora_loadable_modules:
+            if "text_encoder" in lora_module:
+                text_encoder = getattr(pipe, lora_module)
+                if lora_module == "text_encoder":
+                    prefix = text_encoder
+                elif lora_module == "text_encoder_2":
+                    prefix = "text_encoder_2"
+
+                logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+                logger.setLevel(logging.INFO)
+
+                with CaptureLogger(logger) as cap_logger:
+                    self.pipeline_class.load_lora_into_text_encoder(
+                        no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
+                    )
+
+                self.assertTrue(cap_logger.out.startswith("No LoRA keys found in the provided state dict"))
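
Because the generic path now logs from diffusers.loaders.peft, the existing test captures that logger unconditionally, and the new loop exercises the text-encoder path separately against diffusers.loaders.lora_pipeline. Outside the test harness, the text-encoder check boils down to something like the sketch below (the CLIP checkpoint id is only an example, and StableDiffusionPipeline stands in for self.pipeline_class; both are assumptions, not part of this commit):

import torch
from transformers import CLIPTextModel
from diffusers import StableDiffusionPipeline
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

# Example text encoder; the test uses whichever text encoder the pipeline under test provides.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)

with CaptureLogger(logger) as cap_logger:
    StableDiffusionPipeline.load_lora_into_text_encoder(
        no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix="text_encoder"
    )

assert cap_logger.out.startswith("No LoRA keys found in the provided state dict")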
