
Commit e3046a5

corrected code quality
1 parent 7c6ddb2 commit e3046a5

3 files changed: +54 -22 lines changed


src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@
     "SanaTransformer2DModel": {
         "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
         "default_subfolder": "transformer",
-    }
+    },
 }
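For context, this registry entry is what lets SanaTransformer2DModel be loaded straight from an original single-file checkpoint: from_single_file() looks the class up here and routes the state dict through convert_sana_transformer_to_diffusers, with "transformer" as the default subfolder for the diffusers-format config. A minimal usage sketch (the checkpoint URL is the one used in the test file below; torch_dtype is optional and assumed supported here):

import torch
from diffusers import SanaTransformer2DModel

# Loading the original .pth checkpoint dispatches through the mapping above,
# which converts the state dict to the diffusers layout on the fly.
ckpt_url = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
transformer = SanaTransformer2DModel.from_single_file(ckpt_url, torch_dtype=torch.bfloat16)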

src/diffusers/loaders/single_file_utils.py

Lines changed: 50 additions & 20 deletions
@@ -121,7 +121,7 @@
         "blocks.0.cross_attn.q_linear.weight",
         "blocks.0.cross_attn.q_linear.bias",
         "blocks.0.cross_attn.kv_linear.weight",
-        "blocks.0.cross_attn.kv_linear.bias"
+        "blocks.0.cross_attn.kv_linear.bias",
     ],
 }

@@ -182,7 +182,7 @@
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
-    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
+    "sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
 }

 # Use to configure model sample size when original config is provided
@@ -2878,16 +2878,19 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):

     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1  # noqa: C401

-
     # Positional and patch embeddings.
     checkpoint.pop("pos_embed")
     converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
     converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

     # Timestep embeddings.
-    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+        "t_embedder.mlp.0.weight"
+    )
     converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
-    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+        "t_embedder.mlp.2.weight"
+    )
     converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
     converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
     converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
@@ -2900,9 +2903,10 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
     converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
     converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")

-
     for i in range(num_layers):
-        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(f"blocks.{i}.scale_shift_table")
+        converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
+            f"blocks.{i}.scale_shift_table"
+        )

         # Self-Attention
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
@@ -2911,30 +2915,56 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])

         # Output Projections
-        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.attn.proj.weight")
-        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.attn.proj.bias")
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.attn.proj.bias"
+        )

         # Cross-Attention
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.weight")
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.q_linear.bias")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.q_linear.bias"
+        )

-        linear_sample_k, linear_sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0)
-        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0)
+        linear_sample_k, linear_sample_v = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
+        )
+        linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
+            checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
+        )
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
         converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias

         # Output Projections
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.weight")
-        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(f"blocks.{i}.cross_attn.proj.bias")
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+            f"blocks.{i}.cross_attn.proj.bias"
+        )

         # MLP
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.weight")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(f"blocks.{i}.mlp.inverted_conv.conv.bias")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.weight")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(f"blocks.{i}.mlp.depth_conv.conv.bias")
-        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(f"blocks.{i}.mlp.point_conv.conv.weight")
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.inverted_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.weight"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
+            f"blocks.{i}.mlp.depth_conv.conv.bias"
+        )
+        converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
+            f"blocks.{i}.mlp.point_conv.conv.weight"
+        )

     # Final layer
     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")

tests/single_file/test_sana_transformer.py

Lines changed: 3 additions & 1 deletion
@@ -20,7 +20,9 @@
 @require_torch_accelerator
 class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
     model_class = SanaTransformer2DModel
-    ckpt_path = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    ckpt_path = (
+        "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
+    )
     alternate_keys_ckpt_paths = [
         "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
     ]
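The test change only wraps the long checkpoint URL in parentheses so the line fits the formatter's length limit; a parenthesized string literal evaluates to exactly the same value. A quick illustration with a placeholder URL:

# Parentheses change the layout of the assignment, not the resulting string.
plain = "https://example.org/very/long/path/to/checkpoint.pth"
wrapped = (
    "https://example.org/very/long/path/to/checkpoint.pth"
)
assert plain == wrapped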
