@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )
 
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in the forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
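+        # Redistribute the scale between the two factors (keeping scale_down * scale_up == scale)
+        # so the down weight is not multiplied by an extremely small number.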
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
             # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
 
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
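+            # If the original checkpoint ships an alpha, fold the alpha / rank scaling into the
+            # converted A/B weights (split across the two matrices by get_alpha_scales).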
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up
 
                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
@@ -2080,6 +2129,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 
 
 def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
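+        # Rebuild dotted diffusers-style module paths from underscore-flattened keys, keeping
+        # known compound names (e.g. "to_q", "add_q_proj") joined by underscores.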
+        def convert_key(key: str) -> str:
+            prefix = "transformer_blocks"
+            if "." in key:
+                base, suffix = key.rsplit(".", 1)
+            else:
+                base, suffix = key, ""
+
+            start = f"{prefix}_"
+            rest = base[len(start) :]
+
+            if "." in rest:
+                head, tail = rest.split(".", 1)
+                tail = "." + tail
+            else:
+                head, tail = rest, ""
+
+            # Protected n-grams that must keep their internal underscores
+            protected = {
+                # pairs
+                ("to", "q"),
+                ("to", "k"),
+                ("to", "v"),
+                ("to", "out"),
+                ("add", "q"),
+                ("add", "k"),
+                ("add", "v"),
+                ("txt", "mlp"),
+                ("img", "mlp"),
+                ("txt", "mod"),
+                ("img", "mod"),
+                # triplets
+                ("add", "q", "proj"),
+                ("add", "k", "proj"),
+                ("add", "v", "proj"),
+                ("to", "add", "out"),
+            }
+
+            prot_by_len = {}
+            for ng in protected:
+                prot_by_len.setdefault(len(ng), set()).add(ng)
+
+            parts = head.split("_")
+            merged = []
+            i = 0
+            lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
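+            # Greedy longest-match: try triplets before pairs so e.g. "add_q_proj" is not split
+            # into "add_q" + "proj".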
+            while i < len(parts):
+                matched = False
+                for L in lengths_desc:
+                    if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                        merged.append("_".join(parts[i : i + L]))
+                        i += L
+                        matched = True
+                        break
+                if not matched:
+                    merged.append(parts[i])
+                    i += 1
+
+            head_converted = ".".join(merged)
+            converted_base = f"{prefix}.{head_converted}{tail}"
+            return converted_base + (("." + suffix) if suffix else "")
+
+        state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
     converted_state_dict = {}
     all_keys = list(state_dict.keys())
     down_key = ".lora_down.weight"