Skip to content

Commit 56a7718

Browse files
committed
updated tests
1 parent 17edc43 commit 56a7718

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

tests/single_file/test_sana_transformer.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import gc
22
import unittest
33

4+
import torch
5+
46
from diffusers import (
57
SanaTransformer2DModel,
68
)
@@ -18,6 +20,10 @@
1820
@require_torch_accelerator
1921
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
2022
model_class = SanaTransformer2DModel
23+
ckpt_path = "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
24+
alternate_keys_ckpt_paths = [
25+
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
26+
]
2127

2228
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
2329

@@ -32,4 +38,22 @@ def tearDown(self):
3238
backend_empty_cache(torch_device)
3339

3440
def test_single_file_components(self):
35-
_ = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
41+
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
42+
model_single_file = self.model_class.from_single_file(self.ckpt_path)
43+
44+
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
45+
for param_name, param_value in model_single_file.config.items():
46+
if param_name in PARAMS_TO_IGNORE:
47+
continue
48+
assert (
49+
model.config[param_name] == param_value
50+
), f"{param_name} differs between single file loading and pretrained loading"
51+
52+
def test_checkpoint_loading(self):
53+
for ckpt_path in self.alternate_keys_ckpt_paths:
54+
torch.cuda.empty_cache()
55+
model = self.model_class.from_single_file(ckpt_path)
56+
57+
del model
58+
gc.collect()
59+
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)