diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 1049bfecbaab..a8aff679b5b6 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -523,13 +523,15 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
             torch_dtype=torch.float16,
             device_map=torch_device,
         )
+        # CUDA device placement works.
+        device = torch_device if torch_device != "rocm" else "cuda"
 
         pipeline_8bit = DiffusionPipeline.from_pretrained(
             self.model_name,
             transformer=transformer_8bit,
             text_encoder_3=text_encoder_3_8bit,
             torch_dtype=torch.float16,
-        ).to("cuda")
+        ).to(device)
 
         # Check if inference works.
         _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index 04ebf9e159f4..d458a3e6d554 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,4 +1,10 @@
 from diffusers.utils import is_torch_available
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    torch_device,
+)
 
 
 if is_torch_available():
@@ -30,9 +36,9 @@ def forward(self, input, *args, **kwargs):
 @torch.no_grad()
 @torch.inference_mode()
 def get_memory_consumption_stat(model, inputs):
-    torch.cuda.reset_peak_memory_stats()
-    torch.cuda.empty_cache()
+    backend_reset_peak_memory_stats(torch_device)
+    backend_empty_cache(torch_device)
 
     model(**inputs)
-    max_memory_mem_allocated = torch.cuda.max_memory_allocated()
-    return max_memory_mem_allocated
+    max_mem_allocated = backend_max_memory_allocated(torch_device)
+    return max_mem_allocated
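
Note: the backend_empty_cache, backend_reset_peak_memory_stats, and backend_max_memory_allocated helpers imported from diffusers.utils.testing_utils dispatch to the per-backend torch API that matches torch_device, which is what lets get_memory_consumption_stat run unchanged across accelerators instead of assuming torch.cuda. The sketch below is not the diffusers implementation, only an illustration of that dispatch pattern; the helper and torch_device names mirror the imports in the diff, while the dispatch logic and the torch.xpu branch (which assumes a PyTorch build with XPU support, 2.4+) are assumptions for illustration.

# Illustrative sketch only: NOT the diffusers implementation of the backend_* helpers.
import torch

# In diffusers tests, torch_device comes from diffusers.utils.testing_utils;
# here we derive a stand-in value just for the sketch.
if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"


def backend_empty_cache(device):
    # Release cached allocator blocks on the active accelerator, if any.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()


def backend_reset_peak_memory_stats(device):
    # Zero the peak-memory counters so the next forward pass is measured cleanly.
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    elif device == "xpu":
        torch.xpu.reset_peak_memory_stats()


def backend_max_memory_allocated(device):
    # Peak bytes allocated since the last reset; 0 on backends without counters.
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    if device == "xpu":
        return torch.xpu.max_memory_allocated()
    return 0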