@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     return converted_state_dict


+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+    """Convert a fal Kontext LoRA state dict (PEFT-style keys prefixed with `base_model.model.`)
+    into diffusers Flux transformer LoRA keys (prefixed with `transformer.`)."""
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+    original_block_prefix = "base_model.model."
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
1367+ if f"double_blocks.{ i } .img_mod.lin.{ lora_key } .bias" in original_state_dict_keys :
1368+ converted_state_dict [f"{ block_prefix } norm1.linear.{ lora_key } .bias" ] = original_state_dict .pop (
1369+ f"{ original_block_prefix } double_blocks.{ i } .img_mod.lin.{ lora_key } .bias"
1370+ )
1371+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(
+                        f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+                    ),
+                    3,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
1420+ if f"double_blocks.{ i } .img_attn.qkv.{ lora_key } .bias" in original_state_dict_keys :
1421+ sample_q_bias , sample_k_bias , sample_v_bias = torch .chunk (
1422+ original_state_dict .pop (f"{ original_block_prefix } double_blocks.{ i } .img_attn.qkv.{ lora_key } .bias" ),
1423+ 3 ,
1424+ dim = 0 ,
1425+ )
1426+ converted_state_dict [f"{ block_prefix } attn.to_q.{ lora_key } .bias" ] = torch .cat ([sample_q_bias ])
1427+ converted_state_dict [f"{ block_prefix } attn.to_k.{ lora_key } .bias" ] = torch .cat ([sample_k_bias ])
1428+ converted_state_dict [f"{ block_prefix } attn.to_v.{ lora_key } .bias" ] = torch .cat ([sample_v_bias ])
1429+
1430+ if f"double_blocks.{ i } .txt_attn.qkv.{ lora_key } .bias" in original_state_dict_keys :
1431+ context_q_bias , context_k_bias , context_v_bias = torch .chunk (
1432+ original_state_dict .pop (f"{ original_block_prefix } double_blocks.{ i } .txt_attn.qkv.{ lora_key } .bias" ),
1433+ 3 ,
1434+ dim = 0 ,
1435+ )
1436+ converted_state_dict [f"{ block_prefix } attn.add_q_proj.{ lora_key } .bias" ] = torch .cat ([context_q_bias ])
1437+ converted_state_dict [f"{ block_prefix } attn.add_k_proj.{ lora_key } .bias" ] = torch .cat ([context_k_bias ])
1438+ converted_state_dict [f"{ block_prefix } attn.add_v_proj.{ lora_key } .bias" ] = torch .cat ([context_v_bias ])
1439+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
1516+ if f"{ original_block_prefix } single_blocks.{ i } .linear1.{ lora_key } .bias" in original_state_dict_keys :
1517+ lora_bias = original_state_dict .pop (f"single_blocks.{ i } .linear1.{ lora_key } .bias" )
1518+ converted_state_dict [f"{ block_prefix } attn.to_q.{ lora_key } .bias" ] = torch .cat ([lora_bias ])
1519+ converted_state_dict [f"{ block_prefix } attn.to_k.{ lora_key } .bias" ] = torch .cat ([lora_bias ])
1520+ converted_state_dict [f"{ block_prefix } attn.to_v.{ lora_key } .bias" ] = torch .cat ([lora_bias ])
1521+ converted_state_dict [f"{ block_prefix } proj_mlp.{ lora_key } .bias" ] = torch .cat ([lora_bias ])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+                    split_size,
+                    dim=0,
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+                        split_size,
+                        dim=0,
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+    # final layer
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+        )
+        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
 def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
     converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
