@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18331833        k .startswith ("time_projection" ) and  k .endswith (".weight" ) for  k  in  original_state_dict 
18341834    )
18351835
def get_alpha_scales(down_weight, alpha_key):
    """Derive the multipliers that bake a LoRA alpha into its two weight matrices.

    Args:
        down_weight: The LoRA down-projection weight; the size of its first
            dimension is used as the rank.
        alpha_key: Key of the alpha entry in ``original_state_dict``. The entry
            is popped (removed) from that dict as a side effect.

    Returns:
        A ``(scale_down, scale_up)`` tuple whose product is ``alpha / rank``.
        Callers multiply the down weight by the first value and the up weight
        by the second.
    """
    # Read the rank first, then consume the alpha entry from the source dict.
    lora_rank = down_weight.shape[0]
    alpha_value = original_state_dict.pop(alpha_key).item()

    # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
    scale_down, scale_up = alpha_value / lora_rank, 1.0

    # Shift factors of two from the up multiplier to the down multiplier until
    # doubling the down multiplier would no longer keep it below the up
    # multiplier; the product of the pair is left unchanged by each step.
    while True:
        if not (scale_down * 2 < scale_up):
            break
        scale_down = scale_down * 2
        scale_up = scale_up / 2

    return scale_down, scale_up
1846+ 
18361847    for  key  in  list (original_state_dict .keys ()):
18371848        if  key .endswith ((".diff" , ".diff_b" )) and  "norm"  in  key :
18381849            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it 
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18521863    for  i  in  range (min_block , max_block  +  1 ):
18531864        # Self-attention 
18541865        for  o , c  in  zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1855-             original_key  =  f"blocks.{ i } { o } { lora_down_key } .weight " 
1856-             converted_key  =  f"blocks. { i } .attn1. { c } .lora_A.weight" 
1857-             if   original_key   in   original_state_dict : 
1858-                  converted_state_dict [ converted_key ]  =   original_state_dict . pop ( original_key ) 
1866+             alpha_key  =  f"blocks.{ i } { o } alpha " 
1867+             has_alpha  =  alpha_key   in   original_state_dict 
1868+             original_key_A   =   f"blocks. { i } .self_attn. { o } . { lora_down_key } .weight" 
1869+             converted_key_A   =   f"blocks. { i } .attn1. { c } .lora_A.weight" 
18591870
1860-             original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1861-             converted_key  =  f"blocks.{ i } { c }  
1862-             if  original_key  in  original_state_dict :
1863-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1871+             original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1872+             converted_key_B  =  f"blocks.{ i } { c }  
1873+ 
1874+             if  has_alpha :
1875+                 down_weight  =  original_state_dict .pop (original_key_A )
1876+                 up_weight  =  original_state_dict .pop (original_key_B )
1877+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1878+                 converted_state_dict [converted_key_A ] =  down_weight  *  scale_down 
1879+                 converted_state_dict [converted_key_B ] =  up_weight  *  scale_up 
1880+ 
1881+             else :
1882+                 if  original_key_A  in  original_state_dict :
1883+                     converted_state_dict [converted_key_A ] =  original_state_dict .pop (original_key_A )
1884+                 if  original_key_B  in  original_state_dict :
1885+                     converted_state_dict [converted_key_B ] =  original_state_dict .pop (original_key_B )
18641886
18651887            original_key  =  f"blocks.{ i } { o }  
18661888            converted_key  =  f"blocks.{ i } { c }  
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18691891
18701892        # Cross-attention 
18711893        for  o , c  in  zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1872-             original_key  =  f"blocks.{ i } { o } { lora_down_key }  
1873-             converted_key  =  f"blocks.{ i } { c }  
1874-             if  original_key  in  original_state_dict :
1875-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1876- 
1877-             original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1878-             converted_key  =  f"blocks.{ i } { c }  
1879-             if  original_key  in  original_state_dict :
1880-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1894+             alpha_key  =  f"blocks.{ i } { o }  
1895+             has_alpha  =  alpha_key  in  original_state_dict 
1896+             original_key_A  =  f"blocks.{ i } { o } { lora_down_key }  
1897+             converted_key_A  =  f"blocks.{ i } { c }  
1898+ 
1899+             original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1900+             converted_key_B  =  f"blocks.{ i } { c }  
1901+ 
1902+             if  original_key_A  in  original_state_dict :
1903+                 down_weight  =  original_state_dict .pop (original_key_A )
1904+                 converted_state_dict [converted_key_A ] =  down_weight 
1905+             if  original_key_B  in  original_state_dict :
1906+                 up_weight  =  original_state_dict .pop (original_key_B )
1907+                 converted_state_dict [converted_key_B ] =  up_weight 
1908+             if  has_alpha :
1909+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1910+                 converted_state_dict [converted_key_A ] *=  scale_down 
1911+                 converted_state_dict [converted_key_B ] *=  scale_up 
18811912
18821913            original_key  =  f"blocks.{ i } { o }  
18831914            converted_key  =  f"blocks.{ i } { c }  
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
18861917
18871918        if  is_i2v_lora :
18881919            for  o , c  in  zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1889-                 original_key  =  f"blocks.{ i } { o } { lora_down_key }  
1890-                 converted_key  =  f"blocks.{ i } { c }  
1891-                 if  original_key  in  original_state_dict :
1892-                     converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1893- 
1894-                 original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1895-                 converted_key  =  f"blocks.{ i } { c }  
1896-                 if  original_key  in  original_state_dict :
1897-                     converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1920+                 alpha_key  =  f"blocks.{ i } { o }  
1921+                 has_alpha  =  alpha_key  in  original_state_dict 
1922+                 original_key_A  =  f"blocks.{ i } { o } { lora_down_key }  
1923+                 converted_key_A  =  f"blocks.{ i } { c }  
1924+ 
1925+                 original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1926+                 converted_key_B  =  f"blocks.{ i } { c }  
1927+ 
1928+                 if  original_key_A  in  original_state_dict :
1929+                     down_weight  =  original_state_dict .pop (original_key_A )
1930+                     converted_state_dict [converted_key_A ] =  down_weight 
1931+                 if  original_key_B  in  original_state_dict :
1932+                     up_weight  =  original_state_dict .pop (original_key_B )
1933+                     converted_state_dict [converted_key_B ] =  up_weight 
1934+                 if  has_alpha :
1935+                     scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1936+                     converted_state_dict [converted_key_A ] *=  scale_down 
1937+                     converted_state_dict [converted_key_B ] *=  scale_up 
18981938
18991939                original_key  =  f"blocks.{ i } { o }  
19001940                converted_key  =  f"blocks.{ i } { c }  
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
19031943
19041944        # FFN 
19051945        for  o , c  in  zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1906-             original_key  =  f"blocks.{ i } { o } { lora_down_key }  
1907-             converted_key  =  f"blocks.{ i } { c }  
1908-             if  original_key  in  original_state_dict :
1909-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1910- 
1911-             original_key  =  f"blocks.{ i } { o } { lora_up_key }  
1912-             converted_key  =  f"blocks.{ i } { c }  
1913-             if  original_key  in  original_state_dict :
1914-                 converted_state_dict [converted_key ] =  original_state_dict .pop (original_key )
1946+             alpha_key  =  f"blocks.{ i } { o }  
1947+             has_alpha  =  alpha_key  in  original_state_dict 
1948+             original_key_A  =  f"blocks.{ i } { o } { lora_down_key }  
1949+             converted_key_A  =  f"blocks.{ i } { c }  
1950+ 
1951+             original_key_B  =  f"blocks.{ i } { o } { lora_up_key }  
1952+             converted_key_B  =  f"blocks.{ i } { c }  
1953+ 
1954+             if  original_key_A  in  original_state_dict :
1955+                 down_weight  =  original_state_dict .pop (original_key_A )
1956+                 converted_state_dict [converted_key_A ] =  down_weight 
1957+             if  original_key_B  in  original_state_dict :
1958+                 up_weight  =  original_state_dict .pop (original_key_B )
1959+                 converted_state_dict [converted_key_B ] =  up_weight 
1960+             if  has_alpha :
1961+                 scale_down , scale_up  =  get_alpha_scales (down_weight , alpha_key )
1962+                 converted_state_dict [converted_key_A ] *=  scale_down 
1963+                 converted_state_dict [converted_key_B ] *=  scale_up 
19151964
19161965            original_key  =  f"blocks.{ i } { o }  
19171966            converted_key  =  f"blocks.{ i } { c }  
0 commit comments