Skip to content

Commit fbdd4f8

Browse files
authored
Merge pull request #13101 from zenghsh3/develop
Fix bug of sampling_id op
2 parents 9bd933d + d0c01d3 commit fbdd4f8

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

paddle/fluid/operators/sampling_id_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
5353
static_cast<T>(context.Attr<float>("min")),
5454
static_cast<T>(context.Attr<float>("max")));
5555

56-
std::vector<T> ids(batch_size);
56+
std::vector<int64_t> ids(batch_size);
5757
for (int i = 0; i < batch_size; ++i) {
5858
T r = dist(engine);
5959
int idx = width - 1;
@@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel<T> {
6363
break;
6464
}
6565
}
66-
ids[i] = ins_vector[idx];
66+
ids[i] = int64_t(idx);
6767
}
6868

6969
std::vector<int64_t> out_dim;

python/paddle/fluid/tests/unittests/test_sampling_id_op.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def setUp(self):
2525
self.op_type = "sampling_id"
2626
self.use_mkldnn = False
2727
self.init_kernel_type()
28-
self.X = np.random.random((8, 4)).astype('float32')
28+
self.X = np.random.random((100, 10)).astype('float32')
2929
self.inputs = {"X": self.X}
30-
self.Y = np.random.random(8).astype('float32')
30+
self.Y = np.random.random(100).astype('int64')
3131
self.outputs = {'Out': self.Y}
3232
self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1}
3333

@@ -36,6 +36,16 @@ def test_check_output(self):
3636
y1 = self.out
3737
self.check_output_customized(self.verify_output)
3838
y2 = self.out
39+
40+
# check dtype
41+
assert y1.dtype == np.int64
42+
assert y2.dtype == np.int64
43+
44+
# check output is index ids of inputs
45+
inputs_ids = np.arange(self.X.shape[1])
46+
assert np.isin(y1, inputs_ids).all()
47+
assert np.isin(y2, inputs_ids).all()
48+
3949
self.assertTrue(np.array_equal(y1, y2))
4050
self.assertEqual(len(y1), len(self.Y))
4151

0 commit comments

Comments
 (0)