Skip to content

Commit e00bcca

Browse files
committed
add one more caswe
1 parent 3803d93 commit e00bcca

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tests/pipelines/animatediff/test_animatediff_controlnet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def test_inference_batch_single_identical(
281281
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
282282
assert max_diff < expected_max_diff
283283

284-
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
285284
def test_to_device(self):
286285
components = self.get_dummy_components()
287286
pipe = self.pipeline_class(**components)
@@ -297,14 +296,14 @@ def test_to_device(self):
297296
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
298297
self.assertTrue(np.isnan(output_cpu).sum() == 0)
299298

300-
pipe.to("cuda")
299+
pipe.to(torch_device)
301300
model_devices = [
302301
component.device.type for component in pipe.components.values() if hasattr(component, "device")
303302
]
304-
self.assertTrue(all(device == "cuda" for device in model_devices))
303+
self.assertTrue(all(device == torch_device for device in model_devices))
305304

306-
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
307-
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
305+
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
306+
self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
308307

309308
def test_to_dtype(self):
310309
components = self.get_dummy_components()

0 commit comments

Comments
 (0)