Skip to content

Commit 32b1a6f

Browse files
committed
dora
1 parent ba5af5a commit 32b1a6f

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
153153
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
154154
)
155155

156-
# Iterate over all LoRA weights.
157-
all_lora_keys = list(state_dict.keys())
158-
for key in all_lora_keys:
159-
if not key.endswith("lora_down.weight"):
160-
continue
161-
156+
# every down weight has a corresponding up weight and potentially an alpha weight
157+
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
158+
for key in lora_keys:
162159
# Extract LoRA name.
163160
lora_name = key.split(".")[0]
164161

@@ -177,9 +174,12 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
177174
# Store DoRA scale if present.
178175
if dora_present_in_unet:
179176
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
180-
unet_state_dict[
181-
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
182-
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
177+
new_key = diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
178+
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
179+
dora_weight = state_dict.pop(lora_name + ".dora_scale")
180+
if dora_weight.dim() <= 2:
181+
dora_weight = dora_weight.squeeze()
182+
unet_state_dict[new_key] = dora_weight
183183

184184
# Handle text encoder LoRAs.
185185
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -194,18 +194,24 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
194194
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
195195

196196
# Store DoRA scale if present.
197-
if dora_present_in_te or dora_present_in_te2:
197+
if (dora_present_in_te or dora_present_in_te2):
198198
dora_scale_key_to_replace_te = (
199199
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
200200
)
201201
if lora_name.startswith(("lora_te_", "lora_te1_")):
202-
te_state_dict[
203-
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
204-
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
202+
new_key = diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
203+
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
204+
dora_weight = state_dict.pop(lora_name + ".dora_scale")
205+
if dora_weight.dim() <= 2:
206+
dora_weight = dora_weight.squeeze()
207+
te_state_dict[new_key] = dora_weight
205208
elif lora_name.startswith("lora_te2_"):
206-
te2_state_dict[
207-
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
208-
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
209+
new_key = diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
210+
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
211+
dora_weight = state_dict.pop(lora_name + ".dora_scale")
212+
if dora_weight.dim() <= 2:
213+
dora_weight = dora_weight.squeeze()
214+
te2_state_dict[new_key] = dora_weight
209215

210216
# Store alpha if present.
211217
if lora_name_alpha in state_dict:
@@ -214,7 +220,8 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
214220

215221
# Check if any keys remain.
216222
if len(state_dict) > 0:
217-
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
223+
all_keys_remaining = sorted(list(state_dict.keys()))
224+
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(all_keys_remaining)}")
218225

219226
logger.info("Non-diffusers checkpoint detected.")
220227

@@ -285,7 +292,7 @@ def _convert_unet_lora_key(key):
285292
pass
286293
else:
287294
pass
288-
295+
289296
return diffusers_name
290297

291298

0 commit comments

Comments
 (0)