@@ -2008,6 +2008,82 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):

     return converted_state_dict

+def convert_sana_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+
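+    # Some checkpoints wrap the transformer weights under a "model.diffusion_model." prefix; strip it so the
+    # keys below can be popped by their bare names.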
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Patch embeddings.
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Caption projection.
+    converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+    converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+    converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+    converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+
+    # AdaLN-single LN
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+    # Shared norm.
+    converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+    # y norm
+    converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+    # Transformer blocks.
+    layer_num = len([key for key in checkpoint.keys() if "blocks" in key and "scale_shift_table" in key])
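+    # Each block contributes exactly one "scale_shift_table" key, so counting those keys gives the number of blocks.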
+    for depth in range(layer_num):
+        # Transformer blocks.
+        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = checkpoint.pop(f"blocks.{depth}.scale_shift_table")
+
+        # Self attention.
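+        # The original checkpoint fuses Q, K and V into a single qkv linear; split its weight into three equal chunks.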
+        q, k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+        # Projection.
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{depth}.attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{depth}.attn.proj.bias")
+
+        # Feed-forward.
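+        # The MLP maps onto a conv stack (inverted conv -> depthwise conv -> pointwise conv); the pointwise
+        # conv has no bias, so only its weight is converted.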
2059+ converted_state_dict [f"transformer_blocks.{ depth } .ff.conv_inverted.weight" ] = checkpoint .pop (f"blocks.{ depth } .mlp.inverted_conv.conv.weight" )
2060+ converted_state_dict [f"transformer_blocks.{ depth } .ff.conv_inverted.bias" ] = checkpoint .pop (f"blocks.{ depth } .mlp.inverted_conv.conv.bias" )
2061+ converted_state_dict [f"transformer_blocks.{ depth } .ff.conv_depth.weight" ] = checkpoint .pop (f"blocks.{ depth } .mlp.depth_conv.conv.weight" )
2062+ converted_state_dict [f"transformer_blocks.{ depth } .ff.conv_depth.bias" ] = checkpoint .pop (f"blocks.{ depth } .mlp.depth_conv.conv.bias" )
2063+ converted_state_dict [f"transformer_blocks.{ depth } .ff.conv_point.weight" ] = checkpoint .pop (f"blocks.{ depth } .mlp.point_conv.conv.weight" )
2064+
2065+ # Cross-attention.
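+        # Q has its own linear while K and V are fused into one kv linear; split the fused weight and bias in half.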
+        q = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+        q_bias = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+        k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+        k_bias, v_bias = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{depth}.cross_attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{depth}.cross_attn.proj.bias")
+
+    # Final block.
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+    return converted_state_dict

 def is_t5_in_single_file(checkpoint):
     if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint: