Skip to content

Commit 680b1c6

Browse files
authored
Fix two tests for custom device (#76116)
* fix some tests
1 parent 4cd0d69 commit 680b1c6

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

test/legacy_test/test_rnn_decode_api.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import unittest
1717

1818
import numpy as np
19-
from op_test import is_custom_device
19+
from op_test import get_device_place, is_custom_device
2020

2121
import paddle
2222
from paddle import Model, base, nn, set_device
@@ -337,11 +337,10 @@ def check_output_with_place(self, place, mode="test"):
337337
)
338338

339339
def check_output(self):
340-
devices = (
341-
["CPU", "GPU"]
342-
if (base.is_compiled_with_cuda() or is_custom_device())
343-
else ["CPU"]
344-
)
340+
devices = ["CPU"]
341+
if base.is_compiled_with_cuda() or is_custom_device():
342+
devices.append(get_device_place())
343+
345344
for device in devices:
346345
place = set_device(device)
347346
self.check_output_with_place(place)

test/legacy_test/test_set_value_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,15 +1801,15 @@ def test_value_input_is_scalar(self):
18011801
)
18021802
class TestSetValueWithStrideError(unittest.TestCase):
18031803
def test_same_place(self):
1804-
x = paddle.rand([5, 10], device=paddle.CUDAPlace(0))
1805-
y = paddle.rand([10, 5], device=paddle.CUDAPlace(0))
1804+
x = paddle.rand([5, 10], device=get_device_place())
1805+
y = paddle.rand([10, 5], device=get_device_place())
18061806
y.transpose_([1, 0])
18071807
x.set_value(y)
18081808
assert x.is_contiguous()
18091809

18101810
def test_different_place1(self):
18111811
# src place != dst place && src is not contiguous
1812-
x = paddle.rand([5, 10], device=paddle.CUDAPlace(0))
1812+
x = paddle.rand([5, 10], device=get_device_place())
18131813
y = paddle.rand([10, 5], device=paddle.CPUPlace())
18141814
y.transpose_([1, 0])
18151815
x.set_value(y)
@@ -1818,7 +1818,7 @@ def test_different_place1(self):
18181818
def test_different_place2(self):
18191819
# src place != dst place && dst is not contiguous
18201820
with self.assertRaises(SystemError):
1821-
x = paddle.ones([5, 4], device=paddle.CUDAPlace(0))
1821+
x = paddle.ones([5, 4], device=get_device_place())
18221822
x.transpose_([1, 0])
18231823
y = paddle.rand([4, 2], device=paddle.CPUPlace())
18241824
assert not x.is_contiguous()

0 commit comments

Comments
 (0)