Skip to content

Commit 9d7e3a2

Browse files
committed
fix
1 parent 7a32ee2 commit 9d7e3a2

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,6 @@ def load_lora_weights(
9999
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
100100
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
101101

102-
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
103-
if is_dora_scale_present:
104-
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
105-
logger.warning(warn_msg)
106-
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
107-
108102
is_correct_format = all("lora" in key for key in state_dict.keys())
109103
if not is_correct_format:
110104
raise ValueError("Invalid LoRA checkpoint.")
@@ -217,6 +211,11 @@ def lora_state_dict(
217211
user_agent=user_agent,
218212
allow_pickle=allow_pickle,
219213
)
214+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
215+
if is_dora_scale_present:
216+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
217+
logger.warning(warn_msg)
218+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
220219

221220
network_alphas = None
222221
# TODO: replace it with a method from `state_dict_utils`
@@ -569,15 +568,6 @@ def load_lora_weights(
569568
**kwargs,
570569
)
571570

572-
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
573-
print(f"{is_dora_scale_present=}")
574-
if is_dora_scale_present:
575-
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
576-
logger.warning(warn_msg)
577-
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
578-
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
579-
print(f"{is_dora_scale_present=}")
580-
581571
is_correct_format = all("lora" in key for key in state_dict.keys())
582572
if not is_correct_format:
583573
raise ValueError("Invalid LoRA checkpoint.")
@@ -700,6 +690,14 @@ def lora_state_dict(
700690
user_agent=user_agent,
701691
allow_pickle=allow_pickle,
702692
)
693+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
694+
print(f"{is_dora_scale_present=}")
695+
if is_dora_scale_present:
696+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
697+
logger.warning(warn_msg)
698+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
699+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
700+
print(f"{is_dora_scale_present=}")
703701

704702
network_alphas = None
705703
# TODO: replace it with a method from `state_dict_utils`
@@ -1609,6 +1607,11 @@ def lora_state_dict(
16091607
user_agent=user_agent,
16101608
allow_pickle=allow_pickle,
16111609
)
1610+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1611+
if is_dora_scale_present:
1612+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
1613+
logger.warning(warn_msg)
1614+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
16121615

16131616
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
16141617

@@ -1681,12 +1684,6 @@ def load_lora_weights(
16811684
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
16821685
)
16831686

1684-
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1685-
if is_dora_scale_present:
1686-
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
1687-
logger.warning(warn_msg)
1688-
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1689-
16901687
is_correct_format = all("lora" in key for key in state_dict.keys())
16911688
if not is_correct_format:
16921689
raise ValueError("Invalid LoRA checkpoint.")

tests/lora/test_lora_layers_sdxl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ def test_integration_logits_for_dora_lora(self):
629629

630630
predicted_slice = images[0, -3:, -3:, -1].flatten()
631631
from diffusers.utils.testing_utils import print_tensor_test
632+
632633
print_tensor_test(predicted_slice)
633634
expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
634635
max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)

0 commit comments

Comments
 (0)