@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1833
1833
k .startswith ("time_projection" ) and k .endswith (".weight" ) for k in original_state_dict
1834
1834
)
1835
1835
1836
def get_alpha_scales(down_weight, alpha_key):
    """Split the LoRA ``alpha / rank`` factor into balanced down/up scales.

    The LoRA forward pass multiplies the product of the two factor matrices
    by ``alpha / rank``.  Folding that whole factor into a single matrix can
    underflow/overflow in low precision, so the factor is redistributed
    between the down and up matrices in powers of two.

    Pops ``alpha_key`` from ``original_state_dict`` (enclosing scope).

    Returns:
        (scale_down, scale_up): multipliers whose product equals alpha / rank.
    """
    rank = down_weight.shape[0]
    alpha = original_state_dict.pop(alpha_key).item()
    total_scale = alpha / rank
    scale_down, scale_up = total_scale, 1.0
    # Shift powers of two from the up-scale onto the down-scale until the
    # two multipliers are within a factor of two of each other.
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up
1836
1847
for key in list (original_state_dict .keys ()):
1837
1848
if key .endswith ((".diff" , ".diff_b" )) and "norm" in key :
1838
1849
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1852
1863
for i in range (min_block , max_block + 1 ):
1853
1864
# Self-attention
1854
1865
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
    # Original (non-diffusers) keys for this projection's LoRA factors.
    alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
    original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
    converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
    original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
    converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"

    # Pop each factor independently: a missing key must not raise, and the
    # locals are reset every iteration so a previous projection's weights
    # can never leak into this one's alpha scaling.
    down_weight = up_weight = None
    if original_key_A in original_state_dict:
        down_weight = original_state_dict.pop(original_key_A)
        converted_state_dict[converted_key_A] = down_weight
    if original_key_B in original_state_dict:
        up_weight = original_state_dict.pop(original_key_B)
        converted_state_dict[converted_key_B] = up_weight

    # Fold the LoRA alpha into the weights only when both factors exist;
    # otherwise discard the alpha so it is not reported as unconverted.
    if alpha_key in original_state_dict:
        if down_weight is not None and up_weight is not None:
            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
            converted_state_dict[converted_key_A] = down_weight * scale_down
            converted_state_dict[converted_key_B] = up_weight * scale_up
        else:
            original_state_dict.pop(alpha_key)
1864
1886
1865
1887
original_key = f"blocks.{ i } .self_attn.{ o } .diff_b"
1866
1888
converted_key = f"blocks.{ i } .attn1.{ c } .lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1869
1891
1870
1892
# Cross-attention
1871
1893
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
    # Original (non-diffusers) keys for this projection's LoRA factors.
    alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
    original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
    converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
    original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
    converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"

    # Reset per iteration so alpha scaling can never pick up a stale
    # down_weight/up_weight from a previous projection (previously this
    # could NameError on the first iteration or silently use the wrong
    # rank when alpha was present without its weight keys).
    down_weight = up_weight = None
    if original_key_A in original_state_dict:
        down_weight = original_state_dict.pop(original_key_A)
        converted_state_dict[converted_key_A] = down_weight
    if original_key_B in original_state_dict:
        up_weight = original_state_dict.pop(original_key_B)
        converted_state_dict[converted_key_B] = up_weight

    # Fold the LoRA alpha into the weights only when both factors exist;
    # otherwise discard the alpha so it is not reported as unconverted.
    if alpha_key in original_state_dict:
        if down_weight is not None and up_weight is not None:
            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
            converted_state_dict[converted_key_A] *= scale_down
            converted_state_dict[converted_key_B] *= scale_up
        else:
            original_state_dict.pop(alpha_key)
1881
1912
1882
1913
original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
1883
1914
converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1886
1917
1887
1918
if is_i2v_lora :
1888
1919
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
    # Original (non-diffusers) keys for the image-conditioning projections.
    alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
    original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
    converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
    original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
    converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"

    # Reset per iteration so alpha scaling can never pick up a stale
    # down_weight/up_weight from a previous projection (previously this
    # could NameError on the first iteration or silently use the wrong
    # rank when alpha was present without its weight keys).
    down_weight = up_weight = None
    if original_key_A in original_state_dict:
        down_weight = original_state_dict.pop(original_key_A)
        converted_state_dict[converted_key_A] = down_weight
    if original_key_B in original_state_dict:
        up_weight = original_state_dict.pop(original_key_B)
        converted_state_dict[converted_key_B] = up_weight

    # Fold the LoRA alpha into the weights only when both factors exist;
    # otherwise discard the alpha so it is not reported as unconverted.
    if alpha_key in original_state_dict:
        if down_weight is not None and up_weight is not None:
            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
            converted_state_dict[converted_key_A] *= scale_down
            converted_state_dict[converted_key_B] *= scale_up
        else:
            original_state_dict.pop(alpha_key)
1898
1938
1899
1939
original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
1900
1940
converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
1903
1943
1904
1944
# FFN
1905
1945
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
    # Original (non-diffusers) keys for this FFN layer's LoRA factors.
    alpha_key = f"blocks.{i}.{o}.alpha"
    original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
    converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
    original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
    converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"

    # Reset per iteration so alpha scaling can never pick up a stale
    # down_weight/up_weight from a previous layer (previously this could
    # NameError on the first iteration or silently use the wrong rank
    # when alpha was present without its weight keys).
    down_weight = up_weight = None
    if original_key_A in original_state_dict:
        down_weight = original_state_dict.pop(original_key_A)
        converted_state_dict[converted_key_A] = down_weight
    if original_key_B in original_state_dict:
        up_weight = original_state_dict.pop(original_key_B)
        converted_state_dict[converted_key_B] = up_weight

    # Fold the LoRA alpha into the weights only when both factors exist;
    # otherwise discard the alpha so it is not reported as unconverted.
    if alpha_key in original_state_dict:
        if down_weight is not None and up_weight is not None:
            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
            converted_state_dict[converted_key_A] *= scale_down
            converted_state_dict[converted_key_B] *= scale_up
        else:
            original_state_dict.pop(alpha_key)
1915
1964
1916
1965
original_key = f"blocks.{ i } .{ o } .diff_b"
1917
1966
converted_key = f"blocks.{ i } .ffn.{ c } .lora_B.bias"
0 commit comments