Skip to content

Commit a990c1c

Browse files
committed
added testcase
1 parent 4829c9e commit a990c1c

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@
182182
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
183183
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
184184
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
185-
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px"},
185+
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
186186
}
187187

188188
# Use to configure model sample size when original config is provided
@@ -2878,6 +2878,7 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
28782878

28792879
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
28802880

2881+
28812882
# Positional and patch embeddings.
28822883
checkpoint.pop("pos_embed")
28832884
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
@@ -2892,6 +2893,7 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
28922893
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
28932894

28942895
# Caption Projection.
2896+
checkpoint.pop("y_embedder.y_embedding")
28952897
converted_state_dict["caption_proj.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
28962898
converted_state_dict["caption_proj.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
28972899
converted_state_dict["caption_proj.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import gc
2+
import unittest
3+
4+
import torch
5+
6+
from diffusers import (
7+
SanaTransformer2DModel,
8+
)
9+
from diffusers.utils.testing_utils import (
10+
backend_empty_cache,
11+
enable_full_determinism,
12+
require_torch_accelerator,
13+
torch_device,
14+
)
15+
16+
17+
enable_full_determinism()
18+
19+
20+
@require_torch_accelerator
21+
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
22+
model_class = SanaTransformer2DModel
23+
24+
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
25+
26+
def setUp(self):
27+
super().setUp()
28+
gc.collect()
29+
backend_empty_cache(torch_device)
30+
31+
def tearDown(self):
32+
super().tearDown()
33+
gc.collect()
34+
backend_empty_cache(torch_device)
35+
36+
def test_single_file_components(self):
37+
_ = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")

0 commit comments

Comments
 (0)