Skip to content

Commit d9c4f1e

Browse files
committed
update
1 parent 3b079ec commit d9c4f1e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,14 +1387,14 @@ def test_float16_inference(self, expected_max_diff=5e-2):
13871387
if "generator" in inputs:
13881388
inputs["generator"] = self.get_generator(0)
13891389

1390-
output = pipe(**inputs)[0]
1390+
output = pipe(**inputs)[0].cpu()
13911391

13921392
fp16_inputs = self.get_dummy_inputs(torch_device)
13931393
# Reset generator in case it is used inside dummy inputs
13941394
if "generator" in fp16_inputs:
13951395
fp16_inputs["generator"] = self.get_generator(0)
13961396

1397-
output_fp16 = pipe_fp16(**fp16_inputs)[0]
1397+
output_fp16 = pipe_fp16(**fp16_inputs)[0].cpu()
13981398
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
13991399
assert max_diff < 1e-2
14001400

0 commit comments

Comments
 (0)