Skip to content

Commit 41693b6

Browse files
committed
optimize code
1 parent 91f63cd commit 41693b6

File tree

1 file changed

+10
-19
lines changed

1 file changed

+10
-19
lines changed

python/paddle/fluid/tests/unittests/test_split_ids_op.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,30 +65,21 @@ def check_with_place(self, place):
6565
]
6666

6767
# expected output selected rows
68-
expected_out0_rows = [0, 9]
69-
expected_out1_rows = [7, 4]
70-
expected_out2_rows = [5]
68+
expected_out_rows = [[0, 9], [7, 4], [5]]
7169

7270
op = Operator("split_ids", Ids="X", Out=outs_name)
7371

7472
op.run(scope, place)
7573

76-
self.assertEqual(outs[0].rows(), expected_out0_rows)
77-
self.assertEqual(outs[1].rows(), expected_out1_rows)
78-
self.assertEqual(outs[2].rows(), expected_out2_rows)
79-
80-
self.assertAlmostEqual(0.0, np.array(outs[0].get_tensor())[0, 0])
81-
self.assertAlmostEqual(1.0, np.array(outs[0].get_tensor())[0, 1])
82-
self.assertAlmostEqual(9.0, np.array(outs[0].get_tensor())[1, 0])
83-
self.assertAlmostEqual(10.0, np.array(outs[0].get_tensor())[1, 1])
84-
85-
self.assertAlmostEqual(7.0, np.array(outs[1].get_tensor())[0, 0])
86-
self.assertAlmostEqual(8.0, np.array(outs[1].get_tensor())[0, 1])
87-
self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 0])
88-
self.assertAlmostEqual(5.0, np.array(outs[1].get_tensor())[1, 1])
89-
90-
self.assertAlmostEqual(5.0, np.array(outs[2].get_tensor())[0, 0])
91-
self.assertAlmostEqual(6.0, np.array(outs[2].get_tensor())[0, 1])
74+
for i in range(len(outs)):
75+
expected_rows = expected_out_rows[i]
76+
self.assertEqual(outs[i].rows(), expected_rows)
77+
for j in range(len(expected_rows)):
78+
row = expected_rows[j]
79+
self.assertAlmostEqual(
80+
float(row), np.array(outs[i].get_tensor())[j, 0])
81+
self.assertAlmostEqual(
82+
float(row + 1), np.array(outs[i].get_tensor())[j, 1])
9283

9384

9485
if __name__ == '__main__':

0 commit comments

Comments
 (0)