         "blocks.0.cross_attn.q_linear.weight",
         "blocks.0.cross_attn.q_linear.bias",
         "blocks.0.cross_attn.kv_linear.weight",
-        "blocks.0.cross_attn.kv_linear.bias"
+        "blocks.0.cross_attn.kv_linear.bias",
     ],
 }
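These cross-attention parameter names are the kind of marker keys a single-file loader can check for when classifying a checkpoint as Sana. A minimal sketch of key-based detection, not the loader's actual logic; the helper name `is_sana_checkpoint` and the file path are hypothetical:

```python
import torch

# Candidate marker keys taken from the list above.
SANA_KEYS = [
    "blocks.0.cross_attn.q_linear.weight",
    "blocks.0.cross_attn.q_linear.bias",
    "blocks.0.cross_attn.kv_linear.weight",
    "blocks.0.cross_attn.kv_linear.bias",
]


def is_sana_checkpoint(state_dict: dict) -> bool:
    # Treat the checkpoint as a Sana candidate only if every marker key is present.
    return all(key in state_dict for key in SANA_KEYS)


state_dict = torch.load("sana_checkpoint.pth", map_location="cpu")  # hypothetical path
print(is_sana_checkpoint(state_dict))
```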
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
-    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
 }

 # Use to configure model sample size when original config is provided
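The mapping above associates each detected model type with a default Hub repository from which the diffusers-format config is expected to be resolved. A hedged usage sketch, assuming the Sana transformer can be loaded through the usual `from_single_file` path; the local checkpoint filename is hypothetical:

```python
import torch
from diffusers import SanaTransformer2DModel

# The config is expected to come from the mapped repo
# ("Efficient-Large-Model/Sana_1600M_1024px_diffusers") unless one is passed explicitly.
transformer = SanaTransformer2DModel.from_single_file(
    "Sana_1600M_1024px.pth",  # hypothetical local single-file checkpoint
    torch_dtype=torch.float16,
)
```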
@@ -2878,16 +2878,19 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
 
-
     # Positional and patch embeddings.
     checkpoint.pop("pos_embed")
     converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
     converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
 
     # Timestep embeddings.
-    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
     converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
-    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
     converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
     converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
     converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
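The timestep-embedding remapping in this hunk reads `t_embedder.mlp.0` and `t_embedder.mlp.2`, which implies the original embedder is a `Linear → SiLU → Linear` stack where indices 0 and 2 hold the parameters. A standalone sketch of that renaming; the layer sizes are illustrative, not read from a real checkpoint:

```python
import torch
from torch import nn

# Stand-in for the original t_embedder MLP: Linear (index 0), SiLU (index 1), Linear (index 2).
t_embedder = nn.Sequential(nn.Linear(256, 1152), nn.SiLU(), nn.Linear(1152, 1152))
checkpoint = {f"t_embedder.mlp.{k}": v for k, v in t_embedder.state_dict().items()}

# Same renaming as in the converter: mlp.0 -> linear_1, mlp.2 -> linear_2.
converted = {}
for src, dst in (("0", "linear_1"), ("2", "linear_2")):
    for param in ("weight", "bias"):
        converted[f"time_embed.emb.timestep_embedder.{dst}.{param}"] = checkpoint.pop(
            f"t_embedder.mlp.{src}.{param}"
        )

assert not checkpoint  # every original key was consumed by the renaming
```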
@@ -2900,9 +2903,10 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
     converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
     converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
 
-
     for i in range(num_layers):
-        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(f"blocks.{i}.scale_shift_table")
+        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+            f"blocks.{i}.scale_shift_table"
+        )
 
         # Self-Attention
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
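The self-attention conversion that begins at the end of this hunk (and continues in the next one) splits the original fused `qkv` projection into separate Q, K, and V weights with `torch.chunk`. A standalone sketch of the split, using an illustrative hidden size:

```python
import torch

hidden_size = 1152  # illustrative; the real value comes from the checkpoint tensor shapes

# The original checkpoint stores a single fused projection of shape (3 * hidden_size, hidden_size).
qkv_weight = torch.randn(3 * hidden_size, hidden_size)

# Chunking along dim 0 yields equal Q, K, V blocks of shape (hidden_size, hidden_size) each.
sample_q, sample_k, sample_v = torch.chunk(qkv_weight, 3, dim=0)
assert sample_q.shape == sample_k.shape == sample_v.shape == (hidden_size, hidden_size)
```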
@@ -2911,30 +2915,56 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
 
         # Output Projections
-        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.attn.proj.weight")
-        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.attn.proj.bias")
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.bias"
+        )
 
         # Cross-Attention
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.weight")
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.bias")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.bias"
+        )
 
-        linear_sample_k, linear_sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0)
-        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0)
+        linear_sample_k, linear_sample_v = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+        )
+        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+        )
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
 
         # Output Projections
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.weight")
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.bias")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.bias"
+        )
 
         # MLP
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.weight")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.bias")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.point_conv.conv.weight"
+        )
 
     # Final layer
     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
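The cross-attention branch in this hunk follows the same pattern, except the original checkpoint fuses only K and V (`kv_linear`), so both the weight and the bias are split into two chunks; the MLP entries are straight key-for-key moves. A minimal sketch of the K/V split, again with an illustrative hidden size:

```python
import torch

hidden_size = 1152  # illustrative; real shapes come from the checkpoint
kv_weight = torch.randn(2 * hidden_size, hidden_size)
kv_bias = torch.randn(2 * hidden_size)

# Two equal chunks along dim 0 recover the separate K and V projections.
linear_sample_k, linear_sample_v = torch.chunk(kv_weight, 2, dim=0)
linear_sample_k_bias, linear_sample_v_bias = torch.chunk(kv_bias, 2, dim=0)

converted = {
    "transformer_blocks.0.attn2.to_k.weight": linear_sample_k,
    "transformer_blocks.0.attn2.to_v.weight": linear_sample_v,
    "transformer_blocks.0.attn2.to_k.bias": linear_sample_k_bias,
    "transformer_blocks.0.attn2.to_v.bias": linear_sample_v_bias,
}
assert converted["transformer_blocks.0.attn2.to_k.weight"].shape == (hidden_size, hidden_size)
```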