@@ -1348,3 +1348,53 @@ def process_block(prefix, index, convert_norm):
13481348 converted_state_dict [f"transformer.{ key } " ] = converted_state_dict .pop (key )
13491349
13501350 return converted_state_dict
1351+
1352+
1353+ def _convert_non_diffusers_wan_lora_to_diffusers (state_dict ):
1354+ converted_state_dict = {}
1355+ original_state_dict = {k [len ("diffusion_model." ) :]: v for k , v in state_dict .items ()}
1356+
1357+ num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in original_state_dict })
1358+
1359+ for i in range (num_blocks ):
1360+ # Self-attention
1361+ for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1362+ converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_A.weight" ] = original_state_dict .pop (
1363+ f"blocks.{ i } .self_attn.{ o } .lora_A.weight"
1364+ )
1365+ converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_B.weight" ] = original_state_dict .pop (
1366+ f"blocks.{ i } .self_attn.{ o } .lora_B.weight"
1367+ )
1368+
1369+ # Cross-attention
1370+ for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1371+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = original_state_dict .pop (
1372+ f"blocks.{ i } .cross_attn.{ o } .lora_A.weight"
1373+ )
1374+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = original_state_dict .pop (
1375+ f"blocks.{ i } .cross_attn.{ o } .lora_B.weight"
1376+ )
1377+ for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1378+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = original_state_dict .pop (
1379+ f"blocks.{ i } .cross_attn.{ o } .lora_A.weight"
1380+ )
1381+ converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = original_state_dict .pop (
1382+ f"blocks.{ i } .cross_attn.{ o } .lora_B.weight"
1383+ )
1384+
1385+ # FFN
1386+ for o , c in zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
1387+ converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_A.weight" ] = original_state_dict .pop (
1388+ f"blocks.{ i } .{ o } .lora_A.weight"
1389+ )
1390+ converted_state_dict [f"blocks.{ i } .ffn.{ c } .lora_B.weight" ] = original_state_dict .pop (
1391+ f"blocks.{ i } .{ o } .lora_B.weight"
1392+ )
1393+
1394+ if len (original_state_dict ) > 0 :
1395+ raise ValueError (f"`state_dict` should be empty at this point but has { original_state_dict .keys ()= } " )
1396+
1397+ for key in list (converted_state_dict .keys ()):
1398+ converted_state_dict [f"transformer.{ key } " ] = converted_state_dict .pop (key )
1399+
1400+ return state_dict
0 commit comments