Skip to content

Commit 116306e

fix-copies
1 parent 7bb6c9f commit 116306e

File tree

1 file changed: 17 additions & 1 deletion

src/diffusers/loaders/lora_pipeline.py

Lines changed: 17 additions & 1 deletion
@@ -737,7 +737,6 @@ def lora_state_dict(
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
-
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -1355,6 +1354,9 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
+        text_encoder_2_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1380,8 +1382,12 @@ def save_lora_weights(
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata: TODO
+            text_encoder_lora_adapter_metadata: TODO
+            text_encoder_2_lora_adapter_metadata: TODO
         """
         state_dict = {}
+        lora_adapter_metadata = {}
 
         if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -1397,13 +1403,23 @@ def save_lora_weights(
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
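
For orientation, the three new `*_lora_adapter_metadata` arguments are run through the same `cls.pack_weights` namespacing as the LoRA state dicts, so each metadata dict ends up keyed under its component prefix before being forwarded to `write_lora_layers`. Below is a minimal sketch of how a caller might supply this metadata, assuming a transformer-based pipeline whose LoRA adapter was configured with peft; the `pipe` object, the hyperparameter values, and the use of `LoraConfig.to_dict()` are illustrative assumptions, not part of this commit.

# Illustrative sketch only: the pipeline object, config values, and the use of
# peft's LoraConfig.to_dict() are assumptions, not taken from this commit.
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict


def add_and_save_transformer_lora(pipe, save_directory: str) -> None:
    # Hypothetical LoRA config attached to the pipeline's transformer.
    lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
    pipe.transformer.add_adapter(lora_config)

    # ... training of the adapter would happen here ...

    transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)

    # New in this commit: per-component adapter metadata is accepted alongside the weights
    # and forwarded (prefixed via pack_weights) to write_lora_layers as lora_adapter_metadata.
    pipe.save_lora_weights(
        save_directory=save_directory,
        transformer_lora_layers=transformer_lora_layers,
        transformer_lora_adapter_metadata=lora_config.to_dict(),
    )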

0 commit comments