
Commit 96864fb

add alpha
1 parent: 3770571

File tree: 1 file changed (+86, -34)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 86 additions & 34 deletions
@@ -1829,6 +1829,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )
 
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
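A quick way to sanity-check the new get_alpha_scales helper: it folds the LoRA's alpha / rank factor into the stored weights, splitting the factor between the down (lora_A) and up (lora_B) matrices while keeping the product of the two split factors equal to alpha / rank. The snippet below is a standalone sketch, not code from this commit: it re-implements the helper outside the converter (taking alpha directly instead of popping it from original_state_dict) and uses made-up tensor shapes purely to verify that invariant.

    import torch

    def get_alpha_scales(down_weight, alpha):
        # Same arithmetic as the helper added above, but taking alpha directly
        # instead of popping it from original_state_dict.
        rank = down_weight.shape[0]
        scale = alpha / rank  # the LoRA forward pass applies alpha / rank
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    rank, in_features, out_features = 16, 64, 32   # made-up sizes
    down = torch.randn(rank, in_features)          # lora_down -> lora_A
    up = torch.randn(out_features, rank)           # lora_up -> lora_B
    alpha = 8.0

    scale_down, scale_up = get_alpha_scales(down, alpha)

    # scale_down * scale_up == alpha / rank, so the effective LoRA delta
    # (up @ down) * alpha / rank is preserved once both tensors are rescaled.
    assert torch.allclose((up * scale_up) @ (down * scale_down), (up @ down) * (alpha / rank), atol=1e-5)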
@@ -1848,15 +1860,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            has_alpha = f"blocks.{i}.self_attn.{o}.alpha" in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
 
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.self_attn.{o}.alpha")
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1865,15 +1887,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
 
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1882,15 +1914,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                has_alpha = f"blocks.{i}.cross_attn.{o}.alpha" in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if has_alpha:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    up_weight = original_state_dict.pop(original_key_B)
+                    scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}.alpha")
+                    converted_state_dict[converted_key_A] = down_weight * scale_down
+                    converted_state_dict[converted_key_B] = up_weight * scale_up
+                else:
+                    if original_key_A in original_state_dict:
+                        converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+
+                    if original_key_B in original_state_dict:
+                        converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1899,15 +1941,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            has_alpha = f"blocks.{i}.{o}.alpha" in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.{o}.alpha")
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
 
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
@@ -2072,4 +2124,4 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
         raise ValueError("Invalid LoRA state dict for LTX-Video.")
     converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
     converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
-    return converted_state_dict
+    return converted_state_dict
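For one concrete projection, the new has_alpha branch behaves roughly as sketched below. This is a standalone illustration, not part of the commit: the toy key names follow the non-diffusers Wan layout shown in the diff, and "lora_down" / "lora_up" are assumed spellings of lora_down_key / lora_up_key, which the real function resolves earlier, outside this diff.

    import torch

    # Toy non-diffusers Wan LoRA entry for block 0, self-attention "q", carrying a per-layer alpha.
    original_state_dict = {
        "blocks.0.self_attn.q.lora_down.weight": torch.randn(16, 64),  # becomes lora_A, shape (rank, in_features)
        "blocks.0.self_attn.q.lora_up.weight": torch.randn(64, 16),    # becomes lora_B, shape (out_features, rank)
        "blocks.0.self_attn.q.alpha": torch.tensor(8.0),
    }
    converted_state_dict = {}

    alpha_key = "blocks.0.self_attn.q.alpha"
    if alpha_key in original_state_dict:                          # the new has_alpha branch
        down_weight = original_state_dict.pop("blocks.0.self_attn.q.lora_down.weight")
        up_weight = original_state_dict.pop("blocks.0.self_attn.q.lora_up.weight")

        rank = down_weight.shape[0]
        scale = original_state_dict.pop(alpha_key).item() / rank  # factor applied in the LoRA forward pass
        scale_down, scale_up = scale, 1.0
        while scale_down * 2 < scale_up:                          # rebalance; scale_down * scale_up stays == scale
            scale_down *= 2
            scale_up /= 2

        converted_state_dict["blocks.0.attn1.to_q.lora_A.weight"] = down_weight * scale_down
        converted_state_dict["blocks.0.attn1.to_q.lora_B.weight"] = up_weight * scale_up
    # else: with no alpha key, the converter keeps the old behavior and only renames the keys.

    print(sorted(converted_state_dict))  # renamed keys, with alpha / rank baked into the tensors
    print(original_state_dict)           # {} -- the alpha key was consumed, nothing is left behind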
