Commit 4829c9e

added diffusers mapping script
1 parent 7264cd3 commit 4829c9e

2 files changed: +88 −0 lines changed


src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,7 @@
     convert_ltx_transformer_checkpoint_to_diffusers,
     convert_ltx_vae_checkpoint_to_diffusers,
     convert_lumina2_to_diffusers,
+    convert_sana_transformer_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
@@ -117,6 +118,10 @@
         "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "SanaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    }
 }
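Note: with the "SanaTransformer2DModel" entry registered in SINGLE_FILE_LOADABLE_CLASSES, an original-format Sana checkpoint should become loadable through the standard single-file path. A minimal usage sketch, assuming SanaTransformer2DModel exposes from_single_file via FromOriginalModelMixin; the checkpoint path and dtype are illustrative, not taken from this commit:

import torch
from diffusers import SanaTransformer2DModel

# Hypothetical local path (or URL) to an original, non-diffusers Sana checkpoint.
ckpt_path = "path/to/original_sana_checkpoint.pth"

# The mapping added above routes this through convert_sana_transformer_to_diffusers.
transformer = SanaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)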

src/diffusers/loaders/single_file_utils.py

Lines changed: 83 additions & 0 deletions
@@ -117,6 +117,12 @@
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "sana": [
+        "blocks.0.cross_attn.q_linear.weight",
+        "blocks.0.cross_attn.q_linear.bias",
+        "blocks.0.cross_attn.kv_linear.weight",
+        "blocks.0.cross_attn.kv_linear.bias"
+    ],
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -176,6 +182,7 @@
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px"},
 }

 # Use to configure model sample size when original config is provided
@@ -662,6 +669,9 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"

+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
+        model_type = "sana"
+
     else:
         model_type = "v1"

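Note: the branch added to infer_diffusers_model_type is purely key-based; a checkpoint is labelled "sana" as soon as any fingerprint key from CHECKPOINT_KEY_NAMES["sana"] appears in its state dict. A self-contained sketch of that check, using dummy tensors with placeholder shapes:

import torch

# Fingerprint keys, copied from CHECKPOINT_KEY_NAMES["sana"] above.
SANA_KEYS = [
    "blocks.0.cross_attn.q_linear.weight",
    "blocks.0.cross_attn.q_linear.bias",
    "blocks.0.cross_attn.kv_linear.weight",
    "blocks.0.cross_attn.kv_linear.bias",
]

# Dummy original-format state dict; shapes are placeholders, not real model dimensions.
checkpoint = {
    "blocks.0.cross_attn.q_linear.weight": torch.zeros(8, 8),
    "x_embedder.proj.weight": torch.zeros(8, 4, 2, 2),
}

model_type = "sana" if any(key in checkpoint for key in SANA_KEYS) else "v1"
print(model_type)  # -> sana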

@@ -2857,3 +2867,76 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
         converted_state_dict[diffusers_key] = checkpoint.pop(key)

     return converted_state_dict
+
+
+def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401
+
+    # Positional and patch embeddings.
+    checkpoint.pop("pos_embed")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Timestep embeddings.
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
+    converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
+
+    # Caption Projection.
+    converted_state_dict["caption_proj.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
+    converted_state_dict["caption_proj.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
+    converted_state_dict["caption_proj.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
+    converted_state_dict["caption_proj.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
+    converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
+
+
+    for i in range(num_layers):
+        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(f"blocks.{i}.scale_shift_table")
+
+        # Self-Attention
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.attn.proj.bias")
+
+        # Cross-Attention
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.weight")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.bias")
+
+        linear_sample_k, linear_sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0)
+        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0)
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
+
+        # Output Projections
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.weight")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.bias")
+
+        # MLP
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.bias")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight")
+
+    # Final layer
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
+
+    return converted_state_dict
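Note: the bulk of convert_sana_transformer_to_diffusers is un-fusing the original projections: the self-attention qkv weight is chunked into separate to_q/to_k/to_v weights, and the cross-attention kv_linear weight and bias are each chunked into to_k/to_v halves. A standalone sketch of that remapping on dummy tensors (hidden size 8 is a placeholder, not the real model width):

import torch

hidden = 8  # placeholder hidden size

# Fused self-attention projection, as stored in the original checkpoint.
qkv_weight = torch.randn(3 * hidden, hidden)
q, k, v = torch.chunk(qkv_weight, 3, dim=0)  # three (hidden, hidden) matrices

# Fused cross-attention K/V projection and bias.
kv_weight = torch.randn(2 * hidden, hidden)
kv_bias = torch.randn(2 * hidden)
k_cross, v_cross = torch.chunk(kv_weight, 2, dim=0)
k_bias, v_bias = torch.chunk(kv_bias, 2, dim=0)

converted = {
    "transformer_blocks.0.attn1.to_q.weight": q,
    "transformer_blocks.0.attn1.to_k.weight": k,
    "transformer_blocks.0.attn1.to_v.weight": v,
    "transformer_blocks.0.attn2.to_k.weight": k_cross,
    "transformer_blocks.0.attn2.to_v.weight": v_cross,
    "transformer_blocks.0.attn2.to_k.bias": k_bias,
    "transformer_blocks.0.attn2.to_v.bias": v_bias,
}
print({name: tuple(t.shape) for name, t in converted.items()})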
