Skip to content

Commit 10eb128

Browse files
committed
[API-Compat] Updated unittests
1 parent 7aae63a commit 10eb128

File tree

2 files changed

+0
-147
lines changed

2 files changed

+0
-147
lines changed

paddle/phi/kernels/cpu/min_max_with_index_kernel.cc

Lines changed: 0 additions & 96 deletions
This file was deleted.

test/legacy_test/test_minmax_with_index_op.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -231,56 +231,5 @@ def test_check_grad(self):
231231
pass
232232

233233

234-
class TestMinMaxWithIndexPlace(unittest.TestCase):
235-
"""min/max_with_index has no CPU version, so when CUDA is not available,
236-
we skip all the above test. A runtime error will be emitted if min/max_with_index
237-
is called on CPU, this unit test tries capturing it.
238-
"""
239-
240-
def init(self):
241-
self.input_shape = [30, 10, 10]
242-
self.data = np.random.randn(30, 10, 10)
243-
244-
def setUp(self):
245-
self.init()
246-
247-
def cpu_place(self):
248-
self.place = core.CPUPlace()
249-
250-
def test_api_static_cpu_err_handling_1(self):
251-
self.cpu_place()
252-
with (
253-
self.assertRaises(RuntimeError),
254-
paddle.static.program_guard(paddle.static.Program()),
255-
):
256-
input = paddle.static.data(
257-
name="input", shape=self.input_shape, dtype="float64"
258-
)
259-
output = max_with_index(input, dim=0)
260-
exe = paddle.static.Executor(self.place)
261-
result = exe.run(
262-
paddle.static.default_main_program(),
263-
feed={'input': self.data},
264-
fetch_list=[output],
265-
)
266-
267-
def test_api_static_cpu_err_handling_2(self):
268-
self.cpu_place()
269-
with (
270-
self.assertRaises(RuntimeError),
271-
paddle.static.program_guard(paddle.static.Program()),
272-
):
273-
input = paddle.static.data(
274-
name="input", shape=self.input_shape, dtype="float32"
275-
)
276-
output = min_with_index(input, dim=-2, keepdim=True)
277-
exe = paddle.static.Executor(self.place)
278-
result = exe.run(
279-
paddle.static.default_main_program(),
280-
feed={'input': self.data.astype(np.float32)},
281-
fetch_list=[output],
282-
)
283-
284-
285234
if __name__ == "__main__":
286235
unittest.main()

0 commit comments

Comments
 (0)