7575 "stable_cascade_stage_b" : "down_blocks.1.0.channelwise.0.weight" ,
7676 "stable_cascade_stage_c" : "clip_txt_mapper.weight" ,
7777 "sd3" : "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias" ,
78+ "sd35_large" : "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight" ,
7879 "animatediff" : "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe" ,
7980 "animatediff_v2" : "mid_block.motion_modules.0.temporal_transformer.norm.bias" ,
8081 "animatediff_sdxl_beta" : "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight" ,
113114 "sd3" : {
114115 "pretrained_model_name_or_path" : "stabilityai/stable-diffusion-3-medium-diffusers" ,
115116 },
117+ "sd35_large" : {
118+ "pretrained_model_name_or_path" : "stabilityai/stable-diffusion-3.5-large" ,
119+ },
116120 "animatediff_v1" : {"pretrained_model_name_or_path" : "guoyww/animatediff-motion-adapter-v1-5" },
117121 "animatediff_v2" : {"pretrained_model_name_or_path" : "guoyww/animatediff-motion-adapter-v1-5-2" },
118122 "animatediff_v3" : {"pretrained_model_name_or_path" : "guoyww/animatediff-motion-adapter-v1-5-3" },
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
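The added shape[-1] == 9216 guard keeps SD3-medium checkpoints on the "sd3" branch, while checkpoints that carry joint block 37 fall through to the new "sd35_large" branch. A rough sketch of that behaviour on dummy checkpoints (import path assumed to be this module; any tensor width other than the checked 9216 is a placeholder):

    import torch
    from diffusers.loaders.single_file_utils import infer_diffusers_model_type

    sd3_bias_key = "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias"
    sd35_marker_key = "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight"

    sd3_medium_like = {sd3_bias_key: torch.zeros(9216)}
    sd35_large_like = {sd3_bias_key: torch.zeros(1), sd35_marker_key: torch.zeros(1)}  # placeholder widths

    infer_diffusers_model_type(sd3_medium_like)  # -> "sd3"
    infer_diffusers_model_type(sd35_large_like)  # -> "sd35_large"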
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
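Both new helpers read what they need directly off the raw state dict; a toy example (key names follow the joint_blocks.{i}.x_block.attn2.* pattern that the conversion loop pops further down; tensor shapes are placeholders):

    import torch

    state_dict = {
        "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(1),
        "joint_blocks.1.x_block.attn2.qkv.weight": torch.zeros(1),
        "context_embedder.weight": torch.zeros(2432, 4096),  # placeholder shape
    }

    get_attn2_layers(state_dict)            # -> (0, 1)
    get_caption_projection_dim(state_dict)  # -> 2432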
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
17501791 f"joint_blocks.{ i } .context_block.attn.proj.bias"
17511792 )
17521793
1794+ if i in dual_attention_layers :
1795+ # Q, K, V
1796+ sample_q2 , sample_k2 , sample_v2 = torch .chunk (
1797+ checkpoint .pop (f"joint_blocks.{ i } .x_block.attn2.qkv.weight" ), 3 , dim = 0
1798+ )
1799+ sample_q2_bias , sample_k2_bias , sample_v2_bias = torch .chunk (
1800+ checkpoint .pop (f"joint_blocks.{ i } .x_block.attn2.qkv.bias" ), 3 , dim = 0
1801+ )
1802+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_q.weight" ] = torch .cat ([sample_q2 ])
1803+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_q.bias" ] = torch .cat ([sample_q2_bias ])
1804+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_k.weight" ] = torch .cat ([sample_k2 ])
1805+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_k.bias" ] = torch .cat ([sample_k2_bias ])
1806+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_v.weight" ] = torch .cat ([sample_v2 ])
1807+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_v.bias" ] = torch .cat ([sample_v2_bias ])
1808+
1809+ # qk norm
1810+ if has_qk_norm :
1811+ converted_state_dict [f"transformer_blocks.{ i } .attn2.norm_q.weight" ] = checkpoint .pop (
1812+ f"joint_blocks.{ i } .x_block.attn2.ln_q.weight"
1813+ )
1814+ converted_state_dict [f"transformer_blocks.{ i } .attn2.norm_k.weight" ] = checkpoint .pop (
1815+ f"joint_blocks.{ i } .x_block.attn2.ln_k.weight"
1816+ )
1817+
1818+ # output projections.
1819+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_out.0.weight" ] = checkpoint .pop (
1820+ f"joint_blocks.{ i } .x_block.attn2.proj.weight"
1821+ )
1822+ converted_state_dict [f"transformer_blocks.{ i } .attn2.to_out.0.bias" ] = checkpoint .pop (
1823+ f"joint_blocks.{ i } .x_block.attn2.proj.bias"
1824+ )
1825+
17531826 # norms.
17541827 converted_state_dict [f"transformer_blocks.{ i } .norm1.linear.weight" ] = checkpoint .pop (
17551828 f"joint_blocks.{ i } .x_block.adaLN_modulation.1.weight"