@@ -1977,14 +1977,34 @@ def get_alpha_scales(down_weight, alpha_key):
                     "time_projection.1.diff_b"
                 )
 
-        if any("head.head" in k for k in state_dict):
-            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
-                f"head.head.{lora_down_key}.weight"
-            )
-            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+        if any("head.head" in k for k in original_state_dict):
+            if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+                converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                    f"head.head.{lora_down_key}.weight"
+                )
+            if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+                converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+                    f"head.head.{lora_up_key}.weight"
+                )
             if "head.head.diff_b" in original_state_dict:
                 converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
 
+            # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+            # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+            # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+            # an identity.
+            if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+                if f"head.head.{lora_down_key}.weight" in state_dict:
+                    logger.info(
+                        f"The state dict seems to have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+                    )
+                converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
+                down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
+                up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
+                converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
+                    *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
+                ).T
+
         for text_time in ["text_embedding", "time_embedding"]:
             if any(text_time in k for k in original_state_dict):
                 for b_n in [0, 2]:
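
Quick sanity check on the identity fallback in the added block: with lora_B built as a (transposed) identity, the effective update lora_B @ lora_A simply reproduces the raw `head.head.diff` tensor, so treating the diff as the down matrix loses nothing. Below is a minimal sketch of that equivalence, assuming the square case where the diff's leading dimension matches the bias dimension; the shapes are placeholders, not the real Wan proj_out sizes.

import torch

# Placeholder dimensions (not the real Wan proj_out sizes); square case assumed,
# i.e. the diff's leading dim equals the bias dim, as in the checkpoint above.
out_features, in_features = 64, 128
diff = torch.randn(out_features, in_features)   # stands in for "head.head.diff"
bias = torch.randn(out_features)                # stands in for "head.head.diff_b"

# Mirror the conversion: raw diff -> lora_A; identity built from
# (diff.shape[0], bias.shape[0]) and transposed -> lora_B.
lora_A = diff
lora_B = torch.eye(diff.shape[0], bias.shape[0], dtype=diff.dtype).T

# The effective LoRA delta equals the original diff, so no information is lost.
delta = lora_B @ lora_A
assert torch.allclose(delta, diff)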