|
28 | 28 | ) |
29 | 29 | from diffusers.pipelines.bria import BriaPipeline |
30 | 30 | from diffusers.utils.testing_utils import ( |
| 31 | + backend_empty_cache, |
31 | 32 | enable_full_determinism, |
32 | 33 | numpy_cosine_similarity_distance, |
33 | | - require_accelerator, |
34 | | - require_torch_gpu, |
| 34 | + require_torch_accelerator, |
35 | 35 | slow, |
36 | 36 | torch_device, |
37 | 37 | ) |
@@ -149,7 +149,7 @@ def test_image_output_shape(self): |
149 | 149 | assert (output_height, output_width) == (expected_height, expected_width) |
150 | 150 |
|
151 | 151 | @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") |
152 | | - @require_accelerator |
| 152 | + @require_torch_accelerator |
153 | 153 | def test_save_load_float16(self, expected_max_diff=1e-2): |
154 | 154 | components = self.get_dummy_components() |
155 | 155 | for name, module in components.items(): |
@@ -237,20 +237,20 @@ def test_torch_dtype_dict(self): |
237 | 237 |
|
238 | 238 |
|
239 | 239 | @slow |
240 | | -@require_torch_gpu |
| 240 | +@require_torch_accelerator |
241 | 241 | class BriaPipelineSlowTests(unittest.TestCase): |
242 | 242 | pipeline_class = BriaPipeline |
243 | 243 | repo_id = "briaai/BRIA-3.2" |
244 | 244 |
|
245 | 245 | def setUp(self): |
246 | 246 | super().setUp() |
247 | 247 | gc.collect() |
248 | | - torch.cuda.empty_cache() |
| 248 | + backend_empty_cache(torch_device) |
249 | 249 |
|
250 | 250 | def tearDown(self): |
251 | 251 | super().tearDown() |
252 | 252 | gc.collect() |
253 | | - torch.cuda.empty_cache() |
| 253 | + backend_empty_cache(torch_device) |
254 | 254 |
|
255 | 255 | def get_inputs(self, device, seed=0): |
256 | 256 | generator = torch.Generator(device="cpu").manual_seed(seed) |
|
0 commit comments