|  | 
| 19 | 19 | import numpy as np | 
| 20 | 20 | import pytest | 
| 21 | 21 | 
 | 
| 22 |  | -from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging | 
|  | 22 | +from diffusers import ( | 
|  | 23 | +    BitsAndBytesConfig, | 
|  | 24 | +    DiffusionPipeline, | 
|  | 25 | +    FluxTransformer2DModel, | 
|  | 26 | +    SanaTransformer2DModel, | 
|  | 27 | +    SD3Transformer2DModel, | 
|  | 28 | +    logging, | 
|  | 29 | +) | 
| 23 | 30 | from diffusers.utils import is_accelerate_version | 
| 24 | 31 | from diffusers.utils.testing_utils import ( | 
| 25 | 32 |     CaptureLogger, | 
| @@ -300,6 +307,33 @@ def test_device_and_dtype_assignment(self): | 
| 300 | 307 |         _ = self.model_fp16.cuda() | 
| 301 | 308 | 
 | 
| 302 | 309 | 
 | 
|  | 310 | +class Bnb8bitDeviceTests(Base8bitTests): | 
|  | 311 | +    def setUp(self) -> None: | 
|  | 312 | +        gc.collect() | 
|  | 313 | +        torch.cuda.empty_cache() | 
|  | 314 | + | 
|  | 315 | +        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) | 
|  | 316 | +        self.model_8bit = SanaTransformer2DModel.from_pretrained( | 
|  | 317 | +            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers", | 
|  | 318 | +            subfolder="transformer", | 
|  | 319 | +            quantization_config=mixed_int8_config, | 
|  | 320 | +        ) | 
|  | 321 | + | 
|  | 322 | +    def tearDown(self): | 
|  | 323 | +        del self.model_8bit | 
|  | 324 | + | 
|  | 325 | +        gc.collect() | 
|  | 326 | +        torch.cuda.empty_cache() | 
|  | 327 | + | 
|  | 328 | +    def test_buffers_device_assignment(self): | 
|  | 329 | +        for buffer_name, buffer in self.model_8bit.named_buffers(): | 
|  | 330 | +            self.assertEqual( | 
|  | 331 | +                buffer.device.type, | 
|  | 332 | +                torch.device(torch_device).type, | 
|  | 333 | +                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.", | 
|  | 334 | +            ) | 
|  | 335 | + | 
|  | 336 | + | 
| 303 | 337 | class BnB8bitTrainingTests(Base8bitTests): | 
| 304 | 338 |     def setUp(self): | 
| 305 | 339 |         gc.collect() | 
|  | 
0 commit comments