Skip to content

Commit 9bbc6dc

Browse files
committed
update
1 parent d952267 commit 9bbc6dc

File tree

1 file changed

+49
-17
lines changed

1 file changed

+49
-17
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,13 +1005,17 @@ def save_lora_weights(
10051005
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
10061006

10071007
if unet_lora_adapter_metadata is not None:
1008-
lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name))
1008+
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
10091009

10101010
if text_encoder_lora_adapter_metadata:
1011-
lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
1011+
lora_adapter_metadata.update(
1012+
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
1013+
)
10121014

10131015
if text_encoder_2_lora_adapter_metadata:
1014-
lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
1016+
lora_adapter_metadata.update(
1017+
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
1018+
)
10151019

10161020
cls.write_lora_layers(
10171021
state_dict=state_dict,
@@ -1459,13 +1463,19 @@ def save_lora_weights(
14591463
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
14601464

14611465
if transformer_lora_adapter_metadata is not None:
1462-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
1466+
lora_adapter_metadata.update(
1467+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
1468+
)
14631469

14641470
if text_encoder_lora_adapter_metadata:
1465-
lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
1471+
lora_adapter_metadata.update(
1472+
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
1473+
)
14661474

14671475
if text_encoder_2_lora_adapter_metadata:
1468-
lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
1476+
lora_adapter_metadata.update(
1477+
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
1478+
)
14691479

14701480
cls.write_lora_layers(
14711481
state_dict=state_dict,
@@ -1804,7 +1814,9 @@ def save_lora_weights(
18041814
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
18051815

18061816
if transformer_lora_adapter_metadata is not None:
1807-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
1817+
lora_adapter_metadata.update(
1818+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
1819+
)
18081820

18091821
# Save the model
18101822
cls.write_lora_layers(
@@ -2376,7 +2388,9 @@ def save_lora_weights(
23762388
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
23772389

23782390
if transformer_lora_adapter_metadata:
2379-
lora_adapter_metadata.update(_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name))
2391+
lora_adapter_metadata.update(
2392+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
2393+
)
23802394

23812395
if text_encoder_lora_adapter_metadata:
23822396
lora_adapter_metadata.update(
@@ -3173,7 +3187,9 @@ def save_lora_weights(
31733187
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
31743188

31753189
if transformer_lora_adapter_metadata is not None:
3176-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
3190+
lora_adapter_metadata.update(
3191+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
3192+
)
31773193

31783194
# Save the model
31793195
cls.write_lora_layers(
@@ -3508,7 +3524,9 @@ def save_lora_weights(
35083524
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
35093525

35103526
if transformer_lora_adapter_metadata is not None:
3511-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
3527+
lora_adapter_metadata.update(
3528+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
3529+
)
35123530

35133531
# Save the model
35143532
cls.write_lora_layers(
@@ -3847,7 +3865,9 @@ def save_lora_weights(
38473865
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
38483866

38493867
if transformer_lora_adapter_metadata is not None:
3850-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
3868+
lora_adapter_metadata.update(
3869+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
3870+
)
38513871

38523872
# Save the model
38533873
cls.write_lora_layers(
@@ -4184,7 +4204,9 @@ def save_lora_weights(
41844204
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
41854205

41864206
if transformer_lora_adapter_metadata is not None:
4187-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
4207+
lora_adapter_metadata.update(
4208+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
4209+
)
41884210

41894211
# Save the model
41904212
cls.write_lora_layers(
@@ -4523,7 +4545,9 @@ def save_lora_weights(
45234545
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
45244546

45254547
if transformer_lora_adapter_metadata is not None:
4526-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
4548+
lora_adapter_metadata.update(
4549+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
4550+
)
45274551

45284552
# Save the model
45294553
cls.write_lora_layers(
@@ -4863,7 +4887,9 @@ def save_lora_weights(
48634887
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
48644888

48654889
if transformer_lora_adapter_metadata is not None:
4866-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
4890+
lora_adapter_metadata.update(
4891+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
4892+
)
48674893

48684894
# Save the model
48694895
cls.write_lora_layers(
@@ -5253,7 +5279,9 @@ def save_lora_weights(
52535279
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
52545280

52555281
if transformer_lora_adapter_metadata is not None:
5256-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
5282+
lora_adapter_metadata.update(
5283+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
5284+
)
52575285

52585286
# Save the model
52595287
cls.write_lora_layers(
@@ -5590,7 +5618,9 @@ def save_lora_weights(
55905618
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
55915619

55925620
if transformer_lora_adapter_metadata is not None:
5593-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
5621+
lora_adapter_metadata.update(
5622+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
5623+
)
55945624

55955625
# Save the model
55965626
cls.write_lora_layers(
@@ -5929,7 +5959,9 @@ def save_lora_weights(
59295959
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
59305960

59315961
if transformer_lora_adapter_metadata is not None:
5932-
lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
5962+
lora_adapter_metadata.update(
5963+
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
5964+
)
59335965

59345966
# Save the model
59355967
cls.write_lora_layers(

0 commit comments

Comments
 (0)