
Commit 4579523 (1 parent: bb6feae)

Commit message: up

1 file changed: src/diffusers/loaders/lora_pipeline.py (+10 additions, −14 deletions)
@@ -3215,28 +3215,24 @@ def save_lora_weights(
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}
 
-        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        if transformer_lora_layers:
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
 
-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
 
-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )
 
     def fuse_lora(
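The shared `cls._save_lora_weights` entry point this commit delegates to is not shown in the hunk. Below is a standalone sketch, not diffusers' actual implementation, of what that helper plausibly does, inferred from the removed lines: pack each component's layers under its name prefix, merge per-component metadata the same way, then write everything once. The function name `save_lora_weights_sketch` and the `write_fn` parameter are illustrative stand-ins for the class method and `cls.write_lora_layers`.

from typing import Any, Callable, Dict, Optional


def _pack_dict_with_prefix(d: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    # Mirrors the helper visible in the removed lines: namespace keys by component.
    return {f"{prefix}.{k}": v for k, v in d.items()}


def save_lora_weights_sketch(
    save_directory: str,
    lora_layers: Dict[str, Dict[str, Any]],             # component name -> layer state dict
    lora_metadata: Dict[str, Optional[Dict[str, Any]]],  # component name -> adapter metadata (or None)
    write_fn: Callable[..., None],                       # stand-in for cls.write_lora_layers
) -> None:
    state_dict: Dict[str, Any] = {}
    adapter_metadata: Dict[str, Any] = {}
    for name, layers in lora_layers.items():
        # For plain dicts this matches what cls.pack_weights(layers, name) did:
        # prefix every key with the component name.
        state_dict.update(_pack_dict_with_prefix(layers, name))
        if lora_metadata.get(name) is not None:
            adapter_metadata.update(_pack_dict_with_prefix(lora_metadata[name], name))
    write_fn(
        state_dict=state_dict,
        save_directory=save_directory,
        lora_adapter_metadata=adapter_metadata,
    )

The design win is visible in the diff itself: each pipeline's `save_lora_weights` now only assembles `{component_name: layers}` dicts, while the prefix-packing and write logic that was previously repeated per pipeline lives in one place.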

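For context, a minimal usage sketch of the public API after this change, assuming a transformer-based pipeline such as FluxPipeline; the save directory, the dummy tensor shapes, and the metadata values are illustrative only:

import torch
from diffusers import FluxPipeline  # assumption: a pipeline whose LoRA mixin defines save_lora_weights

# Hypothetical LoRA state dict (parameter name -> tensor); in practice these
# tensors come from a trained adapter, e.g. via peft's get_peft_model_state_dict.
transformer_lora_state = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(3072, 16),
}

FluxPipeline.save_lora_weights(
    save_directory="./my-lora",                      # hypothetical output path
    transformer_lora_layers=transformer_lora_state,
    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},  # illustrative metadata
)

Note that after this commit, calling the method with no layer dicts at all raises the broader ValueError shown in the diff, rather than unconditionally requiring `transformer_lora_layers`.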