Skip to content

Commit bcb0924

Browse files
committed
pr comments
1 parent b7e24d9 commit bcb0924

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,18 +1896,16 @@ def get_alpha_scales(down_weight, alpha_key):
18961896
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
18971897
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
18981898

1899-
if has_alpha:
1899+
if original_key_A in original_state_dict:
19001900
down_weight = original_state_dict.pop(original_key_A)
1901+
converted_state_dict[converted_key_A] = down_weight
1902+
if original_key_B in original_state_dict:
19011903
up_weight = original_state_dict.pop(original_key_B)
1904+
converted_state_dict[converted_key_B] = up_weight
1905+
if has_alpha:
19021906
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1903-
converted_state_dict[converted_key_A] = down_weight * scale_down
1904-
converted_state_dict[converted_key_B] = up_weight * scale_up
1905-
else:
1906-
if original_key_A in original_state_dict:
1907-
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
1908-
1909-
if original_key_B in original_state_dict:
1910-
converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
1907+
converted_state_dict[converted_key_A] *= scale_down
1908+
converted_state_dict[converted_key_B] *= scale_up
19111909

19121910
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
19131911
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1924,18 +1922,16 @@ def get_alpha_scales(down_weight, alpha_key):
19241922
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
19251923
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
19261924

1927-
if has_alpha:
1925+
if original_key_A in original_state_dict:
19281926
down_weight = original_state_dict.pop(original_key_A)
1927+
converted_state_dict[converted_key_A] = down_weight
1928+
if original_key_B in original_state_dict:
19291929
up_weight = original_state_dict.pop(original_key_B)
1930+
converted_state_dict[converted_key_B] = up_weight
1931+
if has_alpha:
19301932
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1931-
converted_state_dict[converted_key_A] = down_weight * scale_down
1932-
converted_state_dict[converted_key_B] = up_weight * scale_up
1933-
else:
1934-
if original_key_A in original_state_dict:
1935-
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
1936-
1937-
if original_key_B in original_state_dict:
1938-
converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
1933+
converted_state_dict[converted_key_A] *= scale_down
1934+
converted_state_dict[converted_key_B] *= scale_up
19391935

19401936
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
19411937
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1952,18 +1948,16 @@ def get_alpha_scales(down_weight, alpha_key):
19521948
original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
19531949
converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
19541950

1955-
if has_alpha:
1951+
if original_key_A in original_state_dict:
19561952
down_weight = original_state_dict.pop(original_key_A)
1953+
converted_state_dict[converted_key_A] = down_weight
1954+
if original_key_B in original_state_dict:
19571955
up_weight = original_state_dict.pop(original_key_B)
1956+
converted_state_dict[converted_key_B] = up_weight
1957+
if has_alpha:
19581958
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
1959-
converted_state_dict[converted_key_A] = down_weight * scale_down
1960-
converted_state_dict[converted_key_B] = up_weight * scale_up
1961-
else:
1962-
if original_key_A in original_state_dict:
1963-
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
1964-
1965-
if original_key_B in original_state_dict:
1966-
converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
1959+
converted_state_dict[converted_key_A] *= scale_down
1960+
converted_state_dict[converted_key_B] *= scale_up
19671961

19681962
original_key = f"blocks.{i}.{o}.diff_b"
19691963
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"

0 commit comments

Comments
 (0)