Skip to content

Commit b7e24d9

Browse files
committed
pr comments
1 parent 0a7be77 commit b7e24d9

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,7 +1888,8 @@ def get_alpha_scales(down_weight, alpha_key):
18881888

18891889
# Cross-attention
18901890
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
1891-
has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
1891+
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
1892+
has_alpha = alpha_key in original_state_dict
18921893
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
18931894
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
18941895

@@ -1898,7 +1899,7 @@ def get_alpha_scales(down_weight, alpha_key):
18981899
if has_alpha:
18991900
down_weight = original_state_dict.pop(original_key_A)
19001901
up_weight = original_state_dict.pop(original_key_B)
1901-
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
1902+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
19021903
converted_state_dict[converted_key_A] = down_weight * scale_down
19031904
converted_state_dict[converted_key_B] = up_weight * scale_up
19041905
else:
@@ -1915,7 +1916,8 @@ def get_alpha_scales(down_weight, alpha_key):
19151916

19161917
if is_i2v_lora:
19171918
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
1918-
has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
1919+
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
1920+
has_alpha = alpha_key in original_state_dict
19191921
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
19201922
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
19211923

@@ -1925,7 +1927,7 @@ def get_alpha_scales(down_weight, alpha_key):
19251927
if has_alpha:
19261928
down_weight = original_state_dict.pop(original_key_A)
19271929
up_weight = original_state_dict.pop(original_key_B)
1928-
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
1930+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
19291931
converted_state_dict[converted_key_A] = down_weight * scale_down
19301932
converted_state_dict[converted_key_B] = up_weight * scale_up
19311933
else:
@@ -1942,7 +1944,8 @@ def get_alpha_scales(down_weight, alpha_key):
19421944

19431945
# FFN
19441946
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
1945-
has_alpha = f"blocks.{i}.{o}.alpha" in original_state_dict
1947+
alpha_key = f"blocks.{i}.{o}.alpha"
1948+
has_alpha = alpha_key in original_state_dict
19461949
original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
19471950
converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
19481951

@@ -1952,7 +1955,7 @@ def get_alpha_scales(down_weight, alpha_key):
19521955
if has_alpha:
19531956
down_weight = original_state_dict.pop(original_key_A)
19541957
up_weight = original_state_dict.pop(original_key_B)
1955-
scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.{o}.alpha")
1958+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
19561959
converted_state_dict[converted_key_A] = down_weight * scale_down
19571960
converted_state_dict[converted_key_B] = up_weight * scale_up
19581961
else:

0 commit comments

Comments
 (0)