Commit 3d56e94

add test
1 parent 570d811 commit 3d56e94

File tree

1 file changed: +35 -1 lines changed

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 35 additions & 1 deletion
@@ -19,7 +19,14 @@
 import numpy as np
 import pytest

-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,

@@ -300,6 +307,33 @@ def test_device_and_dtype_assignment(self):
         _ = self.model_fp16.cuda()


+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
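
For reference, a standalone sketch of the loading-and-assertion pattern the new Bnb8bitDeviceTests exercises. The repo id, subfolder, and quantization config are taken from the diff above; the hard-coded torch_device value is an assumption for illustration (the test suite resolves it via diffusers.utils.testing_utils).

# Standalone sketch (not part of the commit): load the Sana transformer in 8-bit
# and check that every registered buffer sits on the expected accelerator device.
import torch

from diffusers import BitsAndBytesConfig, SanaTransformer2DModel

torch_device = "cuda"  # assumption; the test suite resolves this dynamically

model_8bit = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

for name, buffer in model_8bit.named_buffers():
    # Mirrors the assertEqual in test_buffers_device_assignment.
    assert buffer.device.type == torch.device(torch_device).type, (
        f"Expected device {torch_device} for {name}, got {buffer.device}."
    )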
