Skip to content

Commit 9d3d68f

Browse files
committed
handle dora.
1 parent 1154243 commit 9d3d68f

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ def load_lora_weights(
9999
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
100100
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
101101

102-
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
102+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
103+
if is_dora_scale_present:
104+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
105+
logger.warning(warn_msg)
106+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
107+
108+
is_correct_format = all("lora" in key for key in state_dict.keys())
103109
if not is_correct_format:
104110
raise ValueError("Invalid LoRA checkpoint.")
105111

@@ -562,7 +568,14 @@ def load_lora_weights(
562568
unet_config=self.unet.config,
563569
**kwargs,
564570
)
565-
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
571+
572+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
573+
if is_dora_scale_present:
574+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
575+
logger.warning(warn_msg)
576+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
577+
578+
is_correct_format = all("lora" in key for key in state_dict.keys())
566579
if not is_correct_format:
567580
raise ValueError("Invalid LoRA checkpoint.")
568581

@@ -1125,7 +1138,13 @@ def load_lora_weights(
11251138
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
11261139
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
11271140

1128-
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
1141+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1142+
if is_dora_scale_present:
1143+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
1144+
logger.warning(warn_msg)
1145+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1146+
1147+
is_correct_format = all("lora" in key for key in state_dict.keys())
11291148
if not is_correct_format:
11301149
raise ValueError("Invalid LoRA checkpoint.")
11311150

@@ -1659,7 +1678,13 @@ def load_lora_weights(
16591678
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
16601679
)
16611680

1662-
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
1681+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
1682+
if is_dora_scale_present:
1683+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
1684+
logger.warning(warn_msg)
1685+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
1686+
1687+
is_correct_format = all("lora" in key for key in state_dict.keys())
16631688
if not is_correct_format:
16641689
raise ValueError("Invalid LoRA checkpoint.")
16651690

@@ -2405,7 +2430,13 @@ def load_lora_weights(
24052430
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
24062431
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
24072432

2408-
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
2433+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
2434+
if is_dora_scale_present:
2435+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with 'dora_scale' from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
2436+
logger.warning(warn_msg)
2437+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
2438+
2439+
is_correct_format = all("lora" in key for key in state_dict.keys())
24092440
if not is_correct_format:
24102441
raise ValueError("Invalid LoRA checkpoint.")
24112442

0 commit comments

Comments
 (0)