Commit 800aa13
Merge remote-tracking branch 'upstream/main'
2 parents: 4804156 + d032408

File tree: 1 file changed, +6 -6 lines changed


tests/pipelines/bria/test_pipeline_bria.py

Lines changed: 6 additions & 6 deletions
@@ -28,10 +28,10 @@
 )
 from diffusers.pipelines.bria import BriaPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_accelerator,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -149,7 +149,7 @@ def test_image_output_shape(self):
         assert (output_height, output_width) == (expected_height, expected_width)

     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
-    @require_accelerator
+    @require_torch_accelerator
     def test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
@@ -237,20 +237,20 @@ def test_torch_dtype_dict(self):


 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class BriaPipelineSlowTests(unittest.TestCase):
     pipeline_class = BriaPipeline
     repo_id = "briaai/BRIA-3.2"

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, seed=0):
         generator = torch.Generator(device="cpu").manual_seed(seed)
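
Note on the change: the diff swaps the CUDA-only test helpers (require_torch_gpu, torch.cuda.empty_cache()) for the device-agnostic ones exported by diffusers.utils.testing_utils (require_torch_accelerator, backend_empty_cache(torch_device)), so the slow Bria pipeline tests can also run on non-CUDA accelerators such as XPU. Below is a minimal sketch of the dispatch pattern such a cache-clearing helper typically follows; it is an assumption for illustration, not the diffusers implementation, and the empty_device_cache name is hypothetical.

# Hypothetical sketch of a backend-agnostic cache-clearing helper
# (illustration only, not the actual backend_empty_cache from diffusers).
import torch


def empty_device_cache(device: str) -> None:
    """Release cached allocator memory on whichever accelerator 'device' names."""
    if device.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device.startswith("xpu") and hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif device.startswith("mps") and torch.backends.mps.is_available():
        torch.mps.empty_cache()
    # CPU has no allocator cache to clear, so nothing to do here.

Keeping the dispatch behind a single helper lets setUp and tearDown stay identical across backends, which is exactly what the hunks above achieve by calling backend_empty_cache(torch_device) instead of torch.cuda.empty_cache().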
