38 changes: 33 additions & 5 deletions src/diffusers/loaders/lora_pipeline.py
@@ -99,7 +99,7 @@ def load_lora_weights(
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
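A quick illustration of why the format check was tightened: after lora_state_dict() filters the state dict, every remaining key must mention "lora", so DoRA-only keys no longer count as valid. A toy sketch with placeholder keys (not real checkpoint contents):

filtered = {"unet.up.lora_A.weight": 0, "unet.up.lora_B.weight": 0}
assert all("lora" in key for key in filtered)  # valid LoRA checkpoint

unfiltered = dict(filtered, **{"unet.mid.dora_scale": 0})
assert not all("lora" in key for key in unfiltered)  # would raise ValueError above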

@@ -211,6 +211,11 @@ def lora_state_dict(
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
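The hunk above warns about and drops DoRA-specific keys before any further state-dict handling. A minimal standalone sketch of that pattern, assuming a plain dict of tensors (the logger setup is illustrative; diffusers uses its own logging wrapper):

import logging

logger = logging.getLogger(__name__)

def drop_dora_scale_keys(state_dict):
    # DoRA checkpoints carry extra `dora_scale` tensors that the LoRA loading
    # path cannot consume yet, so warn and strip them before proceeding.
    if any("dora_scale" in k for k in state_dict):
        logger.warning("Filtering out `dora_scale` keys from the state dict.")
        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
    return state_dict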
@@ -562,7 +567,8 @@ def load_lora_weights(
unet_config=self.unet.config,
**kwargs,
)
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

@@ -684,6 +690,11 @@ def lora_state_dict(
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
@@ -1125,7 +1136,13 @@ def load_lora_weights(
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
Collaborator: I see that we sometimes add this code to lora_state_dict() and sometimes in load_lora_weights(). Is there any reason we do this differently?

@sayakpaul (Member Author), Oct 8, 2024: Update: I see what you mean. Will update accordingly.

@sayakpaul (Member Author): In e08cf74, I have tried to make this more consistent by moving all the "dora_scale"-check-related code to lora_state_dict(). LMK if that works for you, @yiyixuxu.
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
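Following the refactor discussed in the thread above, filtering lives in lora_state_dict() while load_lora_weights() only validates the already-filtered dict. A condensed sketch of that division of labor (load_checkpoint is a hypothetical stand-in for the actual hub/file loading, not a diffusers API):

def lora_state_dict(src):
    state_dict = load_checkpoint(src)  # hypothetical loader
    return {k: v for k, v in state_dict.items() if "dora_scale" not in k}

def load_lora_weights(src):
    state_dict = lora_state_dict(src)  # dora_scale keys already stripped
    if not all("lora" in key for key in state_dict):
        raise ValueError("Invalid LoRA checkpoint.")
    return state_dict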

@@ -1587,6 +1604,11 @@ def lora_state_dict(
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.

@@ -1659,7 +1681,7 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)

is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

@@ -2405,7 +2427,13 @@ def load_lora_weights(
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue: https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

20 changes: 13 additions & 7 deletions tests/lora/test_lora_layers_sdxl.py
@@ -33,8 +33,10 @@
StableDiffusionXLPipeline,
T2IAdapter,
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
CaptureLogger,
load_image,
nightly,
numpy_cosine_similarity_distance,
@@ -620,14 +622,18 @@ def test_integration_logits_for_dora_lora(self):
pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
pipeline.enable_model_cpu_offload()

images = pipeline(
"photo of ohwx dog",
num_inference_steps=10,
generator=torch.manual_seed(0),
output_type="np",
).images
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
images = pipeline(
"photo of ohwx dog",
num_inference_steps=10,
generator=torch.manual_seed(0),
output_type="np",
).images
assert "It seems like you are using a DoRA checkpoint" in cap_logger.out

predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886])
max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
assert max_diff < 1e-3
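For reference, a minimal sketch of the CaptureLogger pattern the updated test relies on; level 30 equals logging.WARNING, and CaptureLogger is the diffusers testing utility imported at the top of the file:

import logging

from diffusers.utils import logging as diffusers_logging
from diffusers.utils.testing_utils import CaptureLogger

logger = diffusers_logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.WARNING)  # 30, same as the literal used in the test

with CaptureLogger(logger) as cap_logger:
    logger.warning("It seems like you are using a DoRA checkpoint ...")

assert "It seems like you are using a DoRA checkpoint" in cap_logger.out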