
Commit 0ac1a39

Commit message: updates
1 parent 7716303, commit 0ac1a39

3 files changed, +25 -34 lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 1 addition & 1 deletion
@@ -906,7 +906,7 @@ def save_function(weights, filename):
                 # We need to be able to serialize the NoneTypes too, otherwise we run into
                 # 'NoneType' object cannot be converted to 'PyString'
                 metadata = {"format": "pt"}
-                if lora_adapter_metadata is not None:
+                if lora_adapter_metadata:
                     for key, value in lora_adapter_metadata.items():
                         if isinstance(value, set):
                             lora_adapter_metadata[key] = list(value)
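
With the tightened check, an empty `lora_adapter_metadata` dict is now skipped along with `None`, so only genuinely populated metadata reaches the set-to-list conversion. A rough, self-contained illustration of that guarded block with placeholder keys (not the actual call site in `lora_base.py`):

# Standalone sketch; the keys below are placeholders, not real adapter metadata.
lora_adapter_metadata = {"target_modules": {"to_q", "to_k"}, "r": 4}

if lora_adapter_metadata:  # falsy values (None or {}) skip the conversion entirely
    for key, value in lora_adapter_metadata.items():
        if isinstance(value, set):
            # sets are not JSON-serializable, so turn them into lists first
            lora_adapter_metadata[key] = list(value)

print(lora_adapter_metadata)  # {'target_modules': ['to_q', 'to_k'], 'r': 4} (list order may vary)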

src/diffusers/loaders/lora_pipeline.py

Lines changed: 20 additions & 30 deletions
@@ -1695,10 +1695,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
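
The same two-line change is repeated in each of the `save_lora_weights` hunks below: the `if transformer_lora_layers:` guard is dropped because the preceding `raise ValueError` already guarantees the argument is non-empty at that point, and the metadata guard switches to an explicit `is not None`. A minimal sketch of the difference between the two checks (plain Python, independent of diffusers):

# Truthiness vs. identity check on an (assumed) metadata dict argument.
transformer_lora_adapter_metadata = {}  # e.g. a caller passes an empty dict

if transformer_lora_adapter_metadata:              # False: empty dicts are skipped as well as None
    print("truthy check runs")

if transformer_lora_adapter_metadata is not None:  # True: only None is excluded
    print("is-not-None check runs")                # an empty dict still reaches the update (a no-op)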
@@ -3020,10 +3019,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -3344,10 +3342,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -3670,10 +3667,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -3996,10 +3992,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -4325,10 +4320,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -4655,10 +4649,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -5014,10 +5007,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -5340,10 +5332,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model
@@ -5666,10 +5657,9 @@ def save_lora_weights(
         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

-        if transformer_lora_adapter_metadata:
+        if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

         # Save the model

src/diffusers/utils/peft_utils.py

Lines changed: 4 additions & 3 deletions
@@ -158,9 +158,10 @@ def get_peft_kwargs(

     if LORA_ADAPTER_METADATA_KEY in peft_state_dict:
         metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY]
-        if prefix is not None:
-            metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
-        return metadata
+        if metadata:
+            if prefix is not None:
+                metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
+            return metadata

     rank_pattern = {}
     alpha_pattern = {}
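
After this change, an empty metadata dict no longer short-circuits `get_peft_kwargs`; it falls through to the usual `rank_pattern`/`alpha_pattern` reconstruction below. A simplified, hypothetical sketch of that control flow (the constant value, helper name, and early-return shape are placeholders, not the real implementation):

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # placeholder value for illustration

def resolve_adapter_metadata(peft_state_dict, prefix=None):
    if LORA_ADAPTER_METADATA_KEY in peft_state_dict:
        metadata = peft_state_dict[LORA_ADAPTER_METADATA_KEY]
        if metadata:  # empty metadata now falls through instead of being returned
            if prefix is not None:
                metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
            return metadata
    return None  # signal the caller to rebuild rank/alpha patterns from the state dict

print(resolve_adapter_metadata({"lora_adapter_metadata": {}}))        # None (falls through)
print(resolve_adapter_metadata({"lora_adapter_metadata": {"r": 4}}))  # {'r': 4}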
