Commit 1f6defd

update
1 parent 710e18b commit 1f6defd

4 files changed: +18 −34 lines changed

tests/single_file/single_file_testing_utils.py

Lines changed: 18 additions & 0 deletions
@@ -119,6 +119,24 @@ def test_single_file_model_parameters(self):
                 f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
             )
 
+    def test_checkpoint_altered_keys_loading(self):
+        # Test loading with checkpoints that have altered keys
+        if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
+            return
+
+        for ckpt_path in self.alternate_keys_ckpt_paths:
+            backend_empty_cache(torch_device)
+
+            single_file_kwargs = {}
+            if hasattr(self, "torch_dtype") and self.torch_dtype:
+                single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
+
+            del model
+            gc.collect()
+            backend_empty_cache(torch_device)
+
 
 class SDSingleFileTesterMixin:
     single_file_kwargs = {}
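
With test_checkpoint_altered_keys_loading living on the mixin, a model test now opts in simply by defining alternate_keys_ckpt_paths (and, optionally, torch_dtype) on the subclass. A minimal sketch, modeled on the Flux test below: the class name and the torch_dtype value are illustrative, and attributes used by the other mixin tests (e.g. ckpt_path) are omitted.

import unittest

import torch

from diffusers import FluxTransformer2DModel

from .single_file_testing_utils import SingleFileModelTesterMixin


class ExampleTransformerSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
    model_class = FluxTransformer2DModel

    # Checkpoints whose state-dict keys differ from the diffusers layout;
    # the mixin returns early if this attribute is missing or empty.
    alternate_keys_ckpt_paths = [
        "https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"
    ]

    repo_id = "black-forest-labs/FLUX.1-dev"
    subfolder = "transformer"

    # Optional and illustrative: when set, the mixin forwards it to from_single_file(...).
    torch_dtype = torch.bfloat16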

tests/single_file/test_lumina2_transformer.py

Lines changed: 0 additions & 12 deletions
@@ -13,17 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import unittest
 
 from diffusers import (
     Lumina2Transformer2DModel,
 )
 
 from ..testing_utils import (
-    backend_empty_cache,
     enable_full_determinism,
-    torch_device,
 )
 from .single_file_testing_utils import SingleFileModelTesterMixin
 
@@ -40,12 +37,3 @@ class Lumina2Transformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
 
     repo_id = "Alpha-VLLM/Lumina-Image-2.0"
     subfolder = "transformer"
-
-    def test_checkpoint_loading(self):
-        for ckpt_path in self.alternate_keys_ckpt_paths:
-            backend_empty_cache(torch_device)
-            model = self.model_class.from_single_file(ckpt_path)
-
-            del model
-            gc.collect()
-            backend_empty_cache(torch_device)

tests/single_file/test_model_flux_transformer_single_file.py

Lines changed: 0 additions & 10 deletions
@@ -37,18 +37,8 @@ class FluxTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
     alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
 
     repo_id = "black-forest-labs/FLUX.1-dev"
-
     subfolder = "transformer"
 
-    def test_checkpoint_loading(self):
-        for ckpt_path in self.alternate_keys_ckpt_paths:
-            backend_empty_cache(torch_device)
-            model = self.model_class.from_single_file(ckpt_path)
-
-            del model
-            gc.collect()
-            backend_empty_cache(torch_device)
-
     def test_device_map_cuda(self):
         backend_empty_cache(torch_device)
         model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")

tests/single_file/test_sana_transformer.py

Lines changed: 0 additions & 12 deletions
@@ -1,14 +1,11 @@
-import gc
 import unittest
 
 from diffusers import (
     SanaTransformer2DModel,
 )
 
 from ..testing_utils import (
-    backend_empty_cache,
     enable_full_determinism,
-    torch_device,
 )
 from .single_file_testing_utils import SingleFileModelTesterMixin
 
@@ -27,12 +24,3 @@ class SanaTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
 
     repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
     subfolder = "transformer"
-
-    def test_checkpoint_loading(self):
-        for ckpt_path in self.alternate_keys_ckpt_paths:
-            backend_empty_cache(torch_device)
-            model = self.model_class.from_single_file(ckpt_path)
-
-            del model
-            gc.collect()
-            backend_empty_cache(torch_device)
