|
11 | 11 | UNet2DModel, |
12 | 12 | ) |
13 | 13 | from diffusers.utils.testing_utils import ( |
| 14 | + Expectations, |
| 15 | + backend_empty_cache, |
14 | 16 | enable_full_determinism, |
15 | 17 | nightly, |
16 | 18 | require_torch_2, |
17 | | - require_torch_gpu, |
| 19 | + require_torch_accelerator, |
18 | 20 | torch_device, |
19 | 21 | ) |
20 | 22 | from diffusers.utils.torch_utils import randn_tensor |
@@ -168,17 +170,17 @@ def test_consistency_model_pipeline_onestep_class_cond(self): |
168 | 170 |
|
169 | 171 |
|
170 | 172 | @nightly |
171 | | -@require_torch_gpu |
| 173 | +@require_torch_accelerator |
172 | 174 | class ConsistencyModelPipelineSlowTests(unittest.TestCase): |
173 | 175 | def setUp(self): |
174 | 176 | super().setUp() |
175 | 177 | gc.collect() |
176 | | - torch.cuda.empty_cache() |
| 178 | + backend_empty_cache(torch_device) |
177 | 179 |
|
178 | 180 | def tearDown(self): |
179 | 181 | super().tearDown() |
180 | 182 | gc.collect() |
181 | | - torch.cuda.empty_cache() |
| 183 | + backend_empty_cache(torch_device) |
182 | 184 |
|
183 | 185 | def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): |
184 | 186 | generator = torch.manual_seed(seed) |
@@ -264,11 +266,19 @@ def test_consistency_model_cd_multistep_flash_attn(self): |
264 | 266 | # Ensure usage of flash attention in torch 2.0 |
265 | 267 | with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): |
266 | 268 | image = pipe(**inputs).images |
| 269 | + |
267 | 270 | assert image.shape == (1, 64, 64, 3) |
268 | 271 |
|
269 | 272 | image_slice = image[0, -3:, -3:, -1] |
270 | 273 |
|
271 | | - expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]) |
| 274 | + expected_slices = Expectations( |
| 275 | + { |
| 276 | + ("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]), |
| 277 | + ("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]), |
| 278 | + ("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]), |
| 279 | + } |
| 280 | + ) |
| 281 | + expected_slice = expected_slices.get_expectation() |
272 | 282 |
|
273 | 283 | assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 |
274 | 284 |
|
|
0 commit comments