Skip to content

Commit f6fde6f

Browse files
committed
docs.
1 parent ae0580a commit f6fde6f

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,6 +2267,8 @@ def save_lora_weights(
22672267
weight_name: str = None,
22682268
save_function: Callable = None,
22692269
safe_serialization: bool = True,
2270+
transformer_lora_adapter_metadata=None,
2271+
text_encoder_lora_adapter_metadata=None,
22702272
):
22712273
r"""
22722274
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -2289,8 +2291,11 @@ def save_lora_weights(
22892291
`DIFFUSERS_SAVE_MODE`.
22902292
safe_serialization (`bool`, *optional*, defaults to `True`):
22912293
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
2294+
transformer_lora_adapter_metadata: TODO
2295+
text_encoder_lora_adapter_metadata: TODO
22922296
"""
22932297
state_dict = {}
2298+
lora_adapter_metadata = {}
22942299

22952300
if not (transformer_lora_layers or text_encoder_lora_layers):
22962301
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
@@ -2301,6 +2306,12 @@ def save_lora_weights(
23012306
if text_encoder_lora_layers:
23022307
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
23032308

2309+
if transformer_lora_adapter_metadata is not None:
2310+
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
2311+
2312+
if text_encoder_lora_adapter_metadata:
2313+
lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
2314+
23042315
# Save the model
23052316
cls.write_lora_layers(
23062317
state_dict=state_dict,
@@ -2309,6 +2320,7 @@ def save_lora_weights(
23092320
weight_name=weight_name,
23102321
save_function=save_function,
23112322
safe_serialization=safe_serialization,
2323+
lora_adapter_metadata=lora_adapter_metadata,
23122324
)
23132325

23142326
def fuse_lora(

0 commit comments

Comments
 (0)