
Commit abb5a58

added single file support for sana transformers
1 parent 1450c2a commit abb5a58

File tree

2 files changed: 81 additions & 0 deletions


src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -38,6 +38,7 @@
     convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
+    convert_sana_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
@@ -82,6 +83,10 @@
         "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "SanaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_sana_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "MotionAdapter": {
         "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
     },
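
Registering SanaTransformer2DModel in the loader's class mapping is what routes single-file loading of an original Sana checkpoint through the converter added below. A minimal usage sketch, assuming a diffusers build that includes this commit and that SanaTransformer2DModel exposes from_single_file via FromOriginalModelMixin; the checkpoint path is a placeholder, not part of this commit:

import torch
from diffusers import SanaTransformer2DModel

# Placeholder path/URL to an original (non-diffusers) Sana checkpoint file.
ckpt_path = "path/to/original_sana_checkpoint.pth"

# "transformer" is the default_subfolder registered above; if automatic config
# detection fails for a given file, a diffusers config repo can be passed via
# config= (see the from_single_file docs).
transformer = SanaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)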

src/diffusers/loaders/single_file_utils.py

Lines changed: 76 additions & 0 deletions
@@ -2008,6 +2008,82 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
 
     return converted_state_dict
 
+def convert_sana_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Patch embeddings.
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Caption projection.
+    converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+    converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+    converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+    converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+
+    # AdaLN-single LN (timestep embedder).
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+
+    # Shared norm.
+    converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+    # Caption (y) norm.
+    converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+    # Transformer blocks.
+    layer_num = len([key for key in checkpoint.keys() if "blocks" in key and "scale_shift_table" in key])
+    for depth in range(layer_num):
+        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = checkpoint.pop(f"blocks.{depth}.scale_shift_table")
+
+        # Self-attention: the fused QKV weight is split into separate Q, K, V projections.
+        q, k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+        # Output projection.
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{depth}.attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{depth}.attn.proj.bias")
+
+        # Feed-forward (inverted, depthwise, and pointwise convolutions).
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{depth}.mlp.inverted_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{depth}.mlp.inverted_conv.conv.bias")
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{depth}.mlp.depth_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{depth}.mlp.depth_conv.conv.bias")
+        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{depth}.mlp.point_conv.conv.weight")
+
+        # Cross-attention: Q has its own linear; the fused KV weight and bias are split in two.
+        q = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+        q_bias = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+        k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+        k_bias, v_bias = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{depth}.cross_attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{depth}.cross_attn.proj.bias")
+
+    # Final block.
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+    return converted_state_dict
 
 def is_t5_in_single_file(checkpoint):
     if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
