117117 "hunyuan-video" : "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias" ,
118118 "instruct-pix2pix" : "model.diffusion_model.input_blocks.0.0.weight" ,
119119 "lumina2" : ["model.diffusion_model.cap_embedder.0.weight" , "cap_embedder.0.weight" ],
120+ "sana" : [
121+ "blocks.0.cross_attn.q_linear.weight" ,
122+ "blocks.0.cross_attn.q_linear.bias" ,
123+ "blocks.0.cross_attn.kv_linear.weight" ,
124+ "blocks.0.cross_attn.kv_linear.bias"
125+ ],
120126}
121127
122128DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
176182 "hunyuan-video" : {"pretrained_model_name_or_path" : "hunyuanvideo-community/HunyuanVideo" },
177183 "instruct-pix2pix" : {"pretrained_model_name_or_path" : "timbrooks/instruct-pix2pix" },
178184 "lumina2" : {"pretrained_model_name_or_path" : "Alpha-VLLM/Lumina-Image-2.0" },
185+ "sana" : {"pretrained_model_name_or_path" : "Efficient-Large-Model/Sana_1600M_1024px" },
179186}
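The two entries above are what let the single-file loader recognize an original Sana checkpoint and pick a default Hub repo for its diffusers config. A minimal usage sketch of the intended end result, assuming the rest of this PR registers SanaTransformer2DModel with the single-file loading mixin (the checkpoint path below is a hypothetical local file, not a real URL):

import torch
from diffusers import SanaTransformer2DModel

# Hypothetical path to an original Sana_1600M_1024px checkpoint file.
transformer = SanaTransformer2DModel.from_single_file(
    "checkpoints/sana_1600m_1024px.pth",
    torch_dtype=torch.bfloat16,
)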

 # Use to configure model sample size when original config is provided
@@ -662,6 +669,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"

+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
     else:
         model_type = "v1"

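A minimal sketch of how the new detection entries are exercised, using a toy checkpoint (dummy tensor shape; only the key name matters here). Any one of the listed Sana cross-attention keys flips infer_diffusers_model_type to "sana", and that model type then indexes the default-repo table above. The imported names live in this module (diffusers.loaders.single_file_utils):

import torch
from diffusers.loaders.single_file_utils import (
    CHECKPOINT_KEY_NAMES,
    DIFFUSERS_DEFAULT_PIPELINE_PATHS,
    infer_diffusers_model_type,
)

toy_checkpoint = {"blocks.0.cross_attn.q_linear.weight": torch.empty(8, 8)}  # dummy shape

assert any(key in toy_checkpoint for key in CHECKPOINT_KEY_NAMES["sana"])
model_type = infer_diffusers_model_type(toy_checkpoint)  # expected: "sana"
default_repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
# expected: "Efficient-Large-Model/Sana_1600M_1024px"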
@@ -2857,3 +2867,76 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
         converted_state_dict[diffusers_key] = checkpoint.pop(key)

     return converted_state_dict
+
+
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
+
+    # Positional and patch embeddings.
+    checkpoint.pop("pos_embed")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+    # Caption Projection.
+    converted_state_dict["caption_proj.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+    converted_state_dict["caption_proj.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+    converted_state_dict["caption_proj.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+    converted_state_dict["caption_proj.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+    converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+    for i in range(num_layers):
+        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(f"blocks.{i}.scale_shift_table")
+
+        # Self-Attention
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.attn.proj.bias")
+
+        # Cross-Attention
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.weight")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.bias")
+
+        linear_sample_k, linear_sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0)
+        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0)
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.bias")
+
+        # MLP
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.bias")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight")
+
+    # Final layer
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+    return converted_state_dict
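The only non-trivial reshaping in the converter above is the splitting of fused projections; everything else is a key rename. A small, self-contained illustration with dummy sizes (not the real Sana dimensions): the original checkpoint stacks Q, K, V along dim 0 of attn.qkv.weight and K, V along dim 0 of cross_attn.kv_linear, so torch.chunk recovers the individual projections, and the torch.cat([...]) wrappers in the loop are effectively copies of single chunks.

import torch

dim = 16  # dummy hidden size
qkv_weight = torch.randn(3 * dim, dim)       # stands in for blocks.{i}.attn.qkv.weight
q, k, v = torch.chunk(qkv_weight, 3, dim=0)  # each chunk has shape (dim, dim)
assert torch.equal(q, qkv_weight[:dim]) and torch.equal(v, qkv_weight[2 * dim:])

kv_weight = torch.randn(2 * dim, dim)        # stands in for blocks.{i}.cross_attn.kv_linear.weight
k_cross, v_cross = torch.chunk(kv_weight, 2, dim=0)
assert k_cross.shape == v_cross.shape == (dim, dim)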