@@ -1829,6 +1829,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18291829 k .startswith ("time_projection" ) and k .endswith (".weight" ) for k in original_state_dict
18301830 )
18311831
def get_alpha_scales(down_weight, alpha_key, state_dict=None):
    """Compute balanced per-matrix scale factors from a LoRA ``alpha`` key.

    LoRA applies ``alpha / rank`` as a multiplier in the forward pass, so when
    baking the scale into the weights we must distribute it between the
    down (A) and up (B) matrices. To keep the two matrices at a similar
    magnitude (better numerical behavior, esp. for fp16), the single scale is
    split into ``scale_down`` and ``scale_up`` such that
    ``scale_down * scale_up == alpha / rank`` and the two factors stay within
    a factor of 2 of each other.

    Args:
        down_weight: The LoRA down-projection (A) weight; its first dimension
            is the LoRA rank.
        alpha_key: Key of the alpha scalar to pop from the state dict.
        state_dict: Mapping to pop ``alpha_key`` from. Defaults to the
            enclosing conversion's ``original_state_dict`` (backward
            compatible with the original closure-based behavior).

    Returns:
        Tuple ``(scale_down, scale_up)`` to multiply into the down and up
        weights respectively. The alpha entry is removed from the dict as a
        side effect.
    """
    if state_dict is None:
        # Preserve original behavior: operate on the enclosing function's dict.
        state_dict = original_state_dict
    rank = down_weight.shape[0]
    alpha = state_dict.pop(alpha_key).item()
    # LoRA is scaled by 'alpha / rank' in the forward pass, so we scale it back here.
    scale = alpha / rank
    scale_down = scale
    scale_up = 1.0
    # Rebalance so both factors are as close as possible while their product
    # stays exactly `scale` (each step multiplies one by 2 and divides the
    # other by 2, leaving the product unchanged).
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up
18321844 for key in list (original_state_dict .keys ()):
18331845 if key .endswith ((".diff" , ".diff_b" )) and "norm" in key :
18341846 # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1848,15 +1860,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18481860 for i in range (min_block , max_block + 1 ):
18491861 # Self-attention
18501862 for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1851- original_key = f"blocks.{ i } .self_attn.{ o } .{ lora_down_key } .weight"
1852- converted_key = f"blocks.{ i } .attn1.{ c } .lora_A.weight"
1853- if original_key in original_state_dict :
1854- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1863+ has_alpha = f"blocks.{ i } .self_attn.{ o } .alpha" in original_state_dict
1864+ original_key_A = f"blocks.{ i } .self_attn.{ o } .{ lora_down_key } .weight"
1865+ converted_key_A = f"blocks.{ i } .attn1.{ c } .lora_A.weight"
18551866
1856- original_key = f"blocks.{ i } .self_attn.{ o } .{ lora_up_key } .weight"
1857- converted_key = f"blocks.{ i } .attn1.{ c } .lora_B.weight"
1858- if original_key in original_state_dict :
1859- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1867+ original_key_B = f"blocks.{ i } .self_attn.{ o } .{ lora_up_key } .weight"
1868+ converted_key_B = f"blocks.{ i } .attn1.{ c } .lora_B.weight"
1869+
1870+ if has_alpha :
1871+ down_weight = original_state_dict .pop (original_key_A )
1872+ up_weight = original_state_dict .pop (original_key_B )
1873+ scale_down , scale_up = get_alpha_scales (down_weight , f"blocks.{ i } .self_attn.{ o } .alpha" )
1874+ converted_state_dict [converted_key_A ] = down_weight * scale_down
1875+ converted_state_dict [converted_key_B ] = up_weight * scale_up
1876+
1877+ else :
1878+ if original_key_A in original_state_dict :
1879+ converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
1880+ if original_key_B in original_state_dict :
1881+ converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
18601882
18611883 original_key = f"blocks.{ i } .self_attn.{ o } .diff_b"
18621884 converted_key = f"blocks.{ i } .attn1.{ c } .lora_B.bias"
@@ -1865,15 +1887,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18651887
18661888 # Cross-attention
18671889 for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1868- original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1869- converted_key = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
1870- if original_key in original_state_dict :
1871- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1890+ has_alpha = f"blocks.{ i } .cross_attn.{ o } .alpha" in original_state_dict
1891+ original_key_A = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1892+ converted_key_A = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
1893+
1894+ original_key_B = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1895+ converted_key_B = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
1896+
1897+ if has_alpha :
1898+ down_weight = original_state_dict .pop (original_key_A )
1899+ up_weight = original_state_dict .pop (original_key_B )
1900+ scale_down , scale_up = get_alpha_scales (down_weight , f"blocks.{ i } .cross_attn.{ o } .alpha" )
1901+ converted_state_dict [converted_key_A ] = down_weight * scale_down
1902+ converted_state_dict [converted_key_B ] = up_weight * scale_up
1903+ else :
1904+ if original_key_A in original_state_dict :
1905+ converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
18721906
1873- original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1874- converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
1875- if original_key in original_state_dict :
1876- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1907+ if original_key_B in original_state_dict :
1908+ converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
18771909
18781910 original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
18791911 converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
@@ -1882,15 +1914,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18821914
18831915 if is_i2v_lora :
18841916 for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1885- original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1886- converted_key = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
1887- if original_key in original_state_dict :
1888- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1889-
1890- original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1891- converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
1892- if original_key in original_state_dict :
1893- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1917+ has_alpha = f"blocks.{ i } .cross_attn.{ o } .alpha" in original_state_dict
1918+ original_key_A = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1919+ converted_key_A = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
1920+
1921+ original_key_B = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1922+ converted_key_B = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
1923+
1924+ if has_alpha :
1925+ down_weight = original_state_dict .pop (original_key_A )
1926+ up_weight = original_state_dict .pop (original_key_B )
1927+ scale_down , scale_up = get_alpha_scales (down_weight , f"blocks.{ i } .cross_attn.{ o } .alpha" )
1928+ converted_state_dict [converted_key_A ] = down_weight * scale_down
1929+ converted_state_dict [converted_key_B ] = up_weight * scale_up
1930+ else :
1931+ if original_key_A in original_state_dict :
1932+ converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
1933+
1934+ if original_key_B in original_state_dict :
1935+ converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
18941936
18951937 original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
18961938 converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
@@ -1899,15 +1941,25 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18991941
19001942 # FFN
19011943 for o , c in zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1902- original_key = f"blocks.{ i } .{ o } .{ lora_down_key } .weight"
1903- converted_key = f"blocks.{ i } .ffn.{ c } .lora_A.weight"
1904- if original_key in original_state_dict :
1905- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1944+ has_alpha = f"blocks.{ i } .{ o } .alpha" in original_state_dict
1945+ original_key_A = f"blocks.{ i } .{ o } .{ lora_down_key } .weight"
1946+ converted_key_A = f"blocks.{ i } .ffn.{ c } .lora_A.weight"
1947+
1948+ original_key_B = f"blocks.{ i } .{ o } .{ lora_up_key } .weight"
1949+ converted_key_B = f"blocks.{ i } .ffn.{ c } .lora_B.weight"
1950+
1951+ if has_alpha :
1952+ down_weight = original_state_dict .pop (original_key_A )
1953+ up_weight = original_state_dict .pop (original_key_B )
1954+ scale_down , scale_up = get_alpha_scales (down_weight , f"blocks.{ i } .{ o } .alpha" )
1955+ converted_state_dict [converted_key_A ] = down_weight * scale_down
1956+ converted_state_dict [converted_key_B ] = up_weight * scale_up
1957+ else :
1958+ if original_key_A in original_state_dict :
1959+ converted_state_dict [converted_key_A ] = original_state_dict .pop (original_key_A )
19061960
1907- original_key = f"blocks.{ i } .{ o } .{ lora_up_key } .weight"
1908- converted_key = f"blocks.{ i } .ffn.{ c } .lora_B.weight"
1909- if original_key in original_state_dict :
1910- converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1961+ if original_key_B in original_state_dict :
1962+ converted_state_dict [converted_key_B ] = original_state_dict .pop (original_key_B )
19111963
19121964 original_key = f"blocks.{ i } .{ o } .diff_b"
19131965 converted_key = f"blocks.{ i } .ffn.{ c } .lora_B.bias"
@@ -2072,4 +2124,4 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
20722124 raise ValueError ("Invalid LoRA state dict for LTX-Video." )
20732125 converted_state_dict = {k .removeprefix (f"{ non_diffusers_prefix } ." ): v for k , v in state_dict .items ()}
20742126 converted_state_dict = {f"transformer.{ k } " : v for k , v in converted_state_dict .items ()}
2075- return converted_state_dict
2127+ return converted_state_dict
0 commit comments