
Commit c0c8856

factor out the overlaps in save_lora_weights().

1 parent: 9d313fc

2 files changed: +90, -88 lines

src/diffusers/loaders/lora_base.py

38 additions, 0 deletions
@@ -1060,6 +1060,44 @@ def save_function(weights, filename):
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")

+    @classmethod
+    def _save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+        lora_metadata: Dict[str, Optional[dict]],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        """
+        Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+        pipeline types.
+        """
+        state_dict = {}
+        final_lora_adapter_metadata = {}
+
+        # Pack the weights for each component (e.g., 'unet', 'text_encoder')
+        for prefix, layers in lora_layers.items():
+            state_dict.update(cls.pack_weights(layers, prefix))
+
+        # Pack the metadata for each component
+        for prefix, metadata in lora_metadata.items():
+            if metadata:
+                final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+        # Call the existing writer function
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+        )
+
     @classmethod
     def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
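For reference, a minimal standalone sketch of the packing step that `_save_lora_weights()` performs on its `lora_layers` and `lora_metadata` dicts before handing off to `write_lora_layers()`. The `pack_weights_sketch`/`pack_metadata_sketch` helpers and the `"{prefix}.{key}"` naming below are illustrative assumptions, not the actual diffusers helpers:

# Illustrative sketch only: pack_weights_sketch / pack_metadata_sketch mimic what
# cls.pack_weights and _pack_dict_with_prefix are assumed to do (prefix each key
# with the component name); they are not the real diffusers implementations.
from typing import Dict, Optional

import torch


def pack_weights_sketch(layers: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    return {f"{prefix}.{key}": value for key, value in layers.items()}


def pack_metadata_sketch(metadata: dict, prefix: str) -> dict:
    return {f"{prefix}.{key}": value for key, value in metadata.items()}


lora_layers = {
    "unet": {"down_blocks.0.lora_A.weight": torch.zeros(4, 8)},
    "text_encoder": {"layers.0.lora_A.weight": torch.zeros(4, 8)},
}
lora_metadata: Dict[str, Optional[dict]] = {"unet": {"r": 4}, "text_encoder": None}

state_dict = {}
final_lora_adapter_metadata = {}

# Same control flow as the helper above: one pass over the layer dicts,
# one pass over the metadata dicts, skipping components whose metadata is None.
for prefix, layers in lora_layers.items():
    state_dict.update(pack_weights_sketch(layers, prefix))
for prefix, metadata in lora_metadata.items():
    if metadata:
        final_lora_adapter_metadata.update(pack_metadata_sketch(metadata, prefix))

print(sorted(state_dict))            # ['text_encoder.layers.0.lora_A.weight', 'unet.down_blocks.0.lora_A.weight']
print(final_lora_adapter_metadata)   # {'unet.r': 4}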

src/diffusers/loaders/lora_pipeline.py

52 additions, 88 deletions
@@ -509,35 +509,29 @@ def save_lora_weights(
             text_encoder_lora_adapter_metadata:
                 LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (unet_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+            lora_layers[cls.unet_name] = unet_lora_layers
+            lora_metadata[cls.unet_name] = unet_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+            lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

-        if unet_lora_adapter_metadata:
-            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")

-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
-
-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        # Delegate to the base helper method
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -1003,44 +997,34 @@ def save_lora_weights(
             text_encoder_2_lora_adapter_metadata:
                 LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
-            raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
-            )
+        lora_layers = {}
+        lora_metadata = {}

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+            lora_layers[cls.unet_name] = unet_lora_layers
+            lora_metadata[cls.unet_name] = unet_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            lora_layers["text_encoder"] = text_encoder_lora_layers
+            lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata

         if text_encoder_2_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+            lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata

-        if unet_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
-
-        if text_encoder_2_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
+        if not lora_layers:
+            raise ValueError(
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
             )

-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -1466,46 +1450,34 @@ def save_lora_weights(
             text_encoder_2_lora_adapter_metadata:
                 LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
-            raise ValueError(
-                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
-            )
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            lora_layers["text_encoder"] = text_encoder_lora_layers
+            lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata

         if text_encoder_2_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+            lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
-
-        if text_encoder_2_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
+        if not lora_layers:
+            raise ValueError(
+                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
             )

-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -2434,37 +2406,29 @@ def save_lora_weights(
             text_encoder_lora_adapter_metadata:
                 LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (transformer_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
-
-        if transformer_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+            lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+            lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        # Delegate to the base helper method
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
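The refactor does not change the caller-facing API: each pipeline's `save_lora_weights()` keeps its signature and simply delegates the packing and writing to the shared helper. A hedged usage sketch, assuming diffusers is installed and using a placeholder layer dict in place of real trained LoRA weights:

# Usage sketch. StableDiffusionPipeline exposes save_lora_weights() via its LoRA
# loader mixin; the tensor name and shape below are placeholders, not weights
# extracted from a trained adapter.
import torch
from diffusers import StableDiffusionPipeline

unet_lora_layers = {
    "down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),  # placeholder key/shape
}

StableDiffusionPipeline.save_lora_weights(
    save_directory="./my_lora",          # output directory
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=None,       # optional; at least one component is required
    safe_serialization=True,             # save as safetensors (the default)
)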
