@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18331833        k .startswith ("time_projection" ) and  k .endswith (".weight" ) for  k  in  original_state_dict 
18341834    )
18351835
def get_alpha_scales(down_weight, alpha_key):
    """Split a LoRA ``alpha / rank`` factor into a (down, up) multiplier pair.

    The rank is read from ``down_weight.shape[0]``. The alpha entry stored
    under ``alpha_key`` is popped (i.e. removed) from the enclosing
    ``original_state_dict`` and converted to a Python scalar via ``.item()``.

    Args:
        down_weight: LoRA down-projection weight; only ``shape[0]`` is read.
        alpha_key: Key of the alpha entry in ``original_state_dict``.

    Returns:
        A ``(scale_down, scale_up)`` tuple, intended to be multiplied into
        the down and up weights respectively.
    """
    lora_rank = down_weight.shape[0]
    alpha_value = original_state_dict.pop(alpha_key).item()

    # LoRA is scaled by 'alpha / rank' in the forward pass, so that factor is
    # folded back into the weights here.
    down_factor, up_factor = alpha_value / lora_rank, 1.0

    # Shift powers of two from the up factor to the down factor for as long as
    # twice the down factor is still below the up factor.
    while True:
        doubled = down_factor * 2
        if not (doubled < up_factor):
            break
        down_factor, up_factor = doubled, up_factor / 2

    return down_factor, up_factor
1846+ 
18361847    for  key  in  list (original_state_dict .keys ()):
18371848        if  key .endswith ((".diff" , ".diff_b" )) and  "norm"  in  key :
18381849            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it 
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18521863    for  i  in  range (min_block , max_block  +  1 ):
18531864        # Self-attention 
18541865        for  o , c  in  zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1855-             original_key  =  f"blocks.{ i } { o } { lora_down_key } .weight " 
1856-             converted_key  =  f"blocks. { i } .attn1. { c } .lora_A.weight" 
1857-             if   original_key   in   original_state_dict : 
1858-                  converted_state_dict [ converted_key ]  =   original_state_dict . pop ( original_key ) 
1866+             alpha_key  =  f"blocks.{ i } { o } alpha " 
1867+             has_alpha  =  alpha_key   in   original_state_dict 
1868+             original_key_A   =   f"blocks. { i } .self_attn. { o } . { lora_down_key } .weight" 
1869+             converted_key_A   =   f"blocks. { i } .attn1. { c } .lora_A.weight" 
18591870
1860-             original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1861-             converted_key  =  f"blocks.{ i } { c }  
1862-             if  original_key  in  original_state_dict :
1863-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1871+             original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1872+             converted_key_B  =  f"blocks.{ i } { c }  
1873+ 
1874+             if  has_alpha :
1875+                 down_weight  =  original_state_dict .pop (original_key_A )
1876+                 up_weight  =  original_state_dict .pop (original_key_B )
1877+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1878+                 converted_state_dict [converted_key_A ] =  down_weight  *  scale_down 
1879+                 converted_state_dict [converted_key_B ] =  up_weight  *  scale_up 
1880+ 
1881+             else :
1882+                 if  original_key_A  in  original_state_dict :
1883+                     converted_state_dict [converted_key_A ] =  original_state_dict .pop (original_key_A )
1884+                 if  original_key_B  in  original_state_dict :
1885+                     converted_state_dict [converted_key_B ] =  original_state_dict .pop (original_key_B )
18641886
18651887            original_key  =  f"blocks.{ i } { o }  
18661888            converted_key  =  f"blocks.{ i } { c }  
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18691891
18701892        # Cross-attention 
18711893        for  o , c  in  zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1872-             original_key  =  f"blocks.{ i } { o } { lora_down_key }  
1873-             converted_key  =  f"blocks.{ i } { c }  
1874-             if  original_key  in  original_state_dict :
1875-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1876- 
1877-             original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1878-             converted_key  =  f"blocks.{ i } { c }  
1879-             if  original_key  in  original_state_dict :
1880-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1894+             alpha_key  =  f"blocks.{ i } { o }  
1895+             has_alpha  =  alpha_key  in  original_state_dict 
1896+             original_key_A  =  f"blocks.{ i } { o } { lora_down_key }  
1897+             converted_key_A  =  f"blocks.{ i } { c }  
1898+ 
1899+             original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1900+             converted_key_B  =  f"blocks.{ i } { c }  
1901+ 
1902+             if  original_key_A  in  original_state_dict :
1903+                 down_weight  =  original_state_dict .pop (original_key_A )
1904+                 converted_state_dict [converted_key_A ] =  down_weight 
1905+             if  original_key_B  in  original_state_dict :
1906+                 up_weight  =  original_state_dict .pop (original_key_B )
1907+                 converted_state_dict [converted_key_B ] =  up_weight 
1908+             if  has_alpha :
1909+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1910+                 converted_state_dict [converted_key_A ] *=  scale_down 
1911+                 converted_state_dict [converted_key_B ] *=  scale_up 
18811912
18821913            original_key  =  f"blocks.{ i } { o }  
18831914            converted_key  =  f"blocks.{ i } { c }  
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18861917
18871918        if  is_i2v_lora :
18881919            for  o , c  in  zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1889-                 original_key  =  f"blocks.{ i } { o } { lora_down_key }  
1890-                 converted_key  =  f"blocks.{ i } { c }  
1891-                 if  original_key  in  original_state_dict :
1892-                     converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1893- 
1894-                 original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1895-                 converted_key  =  f"blocks.{ i } { c }  
1896-                 if  original_key  in  original_state_dict :
1897-                     converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1920+                 alpha_key  =  f"blocks.{ i } { o }  
1921+                 has_alpha  =  alpha_key  in  original_state_dict 
1922+                 original_key_A  =  f"blocks.{ i } { o } { lora_down_key }  
1923+                 converted_key_A  =  f"blocks.{ i } { c }  
1924+ 
1925+                 original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1926+                 converted_key_B  =  f"blocks.{ i } { c }  
1927+ 
1928+                 if  original_key_A  in  original_state_dict :
1929+                     down_weight  =  original_state_dict .pop (original_key_A )
1930+                     converted_state_dict [converted_key_A ] =  down_weight 
1931+                 if  original_key_B  in  original_state_dict :
1932+                     up_weight  =  original_state_dict .pop (original_key_B )
1933+                     converted_state_dict [converted_key_B ] =  up_weight 
1934+                 if  has_alpha :
1935+                     scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1936+                     converted_state_dict [converted_key_A ] *=  scale_down 
1937+                     converted_state_dict [converted_key_B ] *=  scale_up 
18981938
18991939                original_key  =  f"blocks.{ i } { o }  
19001940                converted_key  =  f"blocks.{ i } { c }  
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
19031943
19041944        # FFN 
19051945        for  o , c  in  zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1906-             original_key  =  f"blocks.{ i } { o } { lora_down_key }  
1907-             converted_key  =  f"blocks.{ i } { c }  
1908-             if  original_key  in  original_state_dict :
1909-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1910- 
1911-             original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1912-             converted_key  =  f"blocks.{ i } { c }  
1913-             if  original_key  in  original_state_dict :
1914-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1946+             alpha_key  =  f"blocks.{ i } { o }  
1947+             has_alpha  =  alpha_key  in  original_state_dict 
1948+             original_key_A  =  f"blocks.{ i } { o } { lora_down_key }  
1949+             converted_key_A  =  f"blocks.{ i } { c }  
1950+ 
1951+             original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1952+             converted_key_B  =  f"blocks.{ i } { c }  
1953+ 
1954+             if  original_key_A  in  original_state_dict :
1955+                 down_weight  =  original_state_dict .pop (original_key_A )
1956+                 converted_state_dict [converted_key_A ] =  down_weight 
1957+             if  original_key_B  in  original_state_dict :
1958+                 up_weight  =  original_state_dict .pop (original_key_B )
1959+                 converted_state_dict [converted_key_B ] =  up_weight 
1960+             if  has_alpha :
1961+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1962+                 converted_state_dict [converted_key_A ] *=  scale_down 
1963+                 converted_state_dict [converted_key_B ] *=  scale_up 
19151964
19161965            original_key  =  f"blocks.{ i } { o }  
19171966            converted_key  =  f"blocks.{ i } { c }  
@@ -2080,6 +2129,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
20802129
20812130
20822131def  _convert_non_diffusers_qwen_lora_to_diffusers (state_dict ):
2132+     has_lora_unet  =  any (k .startswith ("lora_unet_" ) for  k  in  state_dict )
2133+     if  has_lora_unet :
2134+         state_dict  =  {k .removeprefix ("lora_unet_" ): v  for  k , v  in  state_dict .items ()}
2135+ 
def convert_key(key: str) -> str:
    """Turn an underscore-flattened ``transformer_blocks_*`` key into a dotted one.

    The part of the key between the ``transformer_blocks_`` prefix and the
    first ``.`` is split on ``_`` and re-joined with ``.``, except that the
    protected n-grams listed below (e.g. ``to_q``, ``add_k_proj``,
    ``img_mlp``) keep their internal underscores. Longer n-grams win over
    shorter ones starting at the same position. Everything from the first
    ``.`` onwards is appended unchanged.

    Keys that do not start with ``transformer_blocks_`` are returned as-is;
    slicing them blindly would rewrite them all to near-identical
    ``transformer_blocks.…`` names.

    NOTE(review): the ``_`` separator in ``start`` and the ``.`` joiner in the
    final f-string were reconstructed from the surrounding logic because the
    reviewed text had both literals cut off -- confirm against upstream.

    Args:
        key: State-dict key with any ``lora_unet_`` prefix already removed.

    Returns:
        The converted key.
    """
    prefix = "transformer_blocks"
    if "." in key:
        base, suffix = key.rsplit(".", 1)
    else:
        base, suffix = key, ""

    start = f"{prefix}_"
    if not base.startswith(start):
        # Not a transformer-block key: leave it alone instead of mangling it.
        return key
    rest = base[len(start) :]

    if "." in rest:
        head, tail = rest.split(".", 1)
        tail = "." + tail
    else:
        head, tail = rest, ""

    # Protected n-grams that must keep their internal underscores
    protected = {
        # pairs
        ("to", "q"),
        ("to", "k"),
        ("to", "v"),
        ("to", "out"),
        ("add", "q"),
        ("add", "k"),
        ("add", "v"),
        ("txt", "mlp"),
        ("img", "mlp"),
        ("txt", "mod"),
        ("img", "mod"),
        # triplets
        ("add", "q", "proj"),
        ("add", "k", "proj"),
        ("add", "v", "proj"),
        ("to", "add", "out"),
    }

    # Group the n-grams by length so that longer matches can be tried first.
    prot_by_len = {}
    for ngram in protected:
        prot_by_len.setdefault(len(ngram), set()).add(ngram)
    lengths_desc = sorted(prot_by_len, reverse=True)

    parts = head.split("_")
    merged = []
    pos = 0
    while pos < len(parts):
        for size in lengths_desc:
            window = tuple(parts[pos : pos + size])
            # A window cut short by the end of `parts` has fewer than `size`
            # items and therefore can never be in the length-`size` bucket.
            if window in prot_by_len[size]:
                merged.append("_".join(window))
                pos += size
                break
        else:
            merged.append(parts[pos])
            pos += 1

    head_converted = ".".join(merged)
    converted_base = f"{prefix}.{head_converted}{tail}"
    return converted_base + (("." + suffix) if suffix else "")
2197+ 
2198+         state_dict  =  {convert_key (k ): v  for  k , v  in  state_dict .items ()}
2199+ 
20832200    converted_state_dict  =  {}
20842201    all_keys  =  list (state_dict .keys ())
20852202    down_key  =  ".lora_down.weight" 
0 commit comments