Skip to content

Commit adbb048

Browse files
authored
[LoRA] fix conversion utility so that lora dora loads correctly (#8688)
fix conversion utility so that lora dora loads correctly
1 parent effe4b9 commit adbb048

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
142142
network_alphas = {}
143143

144144
# Check for DoRA-enabled LoRAs.
145-
if any(
146-
"dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
147-
for k in state_dict
148-
):
145+
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
146+
dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
147+
dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
148+
if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
149149
if is_peft_version("<", "0.9.0"):
150150
raise ValueError(
151151
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
@@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
173173
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
174174

175175
# Store DoRA scale if present.
176-
if "dora_scale" in state_dict:
176+
if dora_present_in_unet:
177177
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
178178
unet_state_dict[
179179
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
@@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
192192
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
193193

194194
# Store DoRA scale if present.
195-
if "dora_scale" in state_dict:
195+
if dora_present_in_te or dora_present_in_te2:
196196
dora_scale_key_to_replace_te = (
197197
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
198198
)
@@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
214214
if len(state_dict) > 0:
215215
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
216216

217-
logger.info("Kohya-style checkpoint detected.")
217+
logger.info("Non-diffusers checkpoint detected.")
218218

219219
# Construct final state dict.
220220
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}

0 commit comments

Comments
 (0)