Skip to content

Commit a64200f

Browse files
author
zenghsh3
committed
refine unittest of sampling_id op
1 parent d749583 commit a64200f

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

python/paddle/fluid/tests/unittests/test_sampling_id_op.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def setUp(self):
2525
self.op_type = "sampling_id"
2626
self.use_mkldnn = False
2727
self.init_kernel_type()
28-
self.X = np.random.random((8, 4)).astype('float32')
28+
self.X = np.random.random((100, 10)).astype('float32')
2929
self.inputs = {"X": self.X}
30-
self.Y = np.random.random(8).astype('float32')
30+
self.Y = np.random.random(100).astype('float32')
3131
self.outputs = {'Out': self.Y}
3232
self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1}
3333

@@ -36,6 +36,16 @@ def test_check_output(self):
3636
y1 = self.out
3737
self.check_output_customized(self.verify_output)
3838
y2 = self.out
39+
40+
# check dtype
41+
assert y1.dtype == np.int64
42+
assert y2.dtype == np.int64
43+
44+
# check output is index ids of inputs
45+
inputs_ids = np.arange(self.X.shape[1])
46+
assert np.isin(y1, inputs_ids).all()
47+
assert np.isin(y2, inputs_ids).all()
48+
3949
self.assertTrue(np.array_equal(y1, y2))
4050
self.assertEqual(len(y1), len(self.Y))
4151

0 commit comments

Comments
 (0)