Skip to content

Commit 4cd504d

Browse files
committed
bug fix
1 parent da2cc99 commit 4cd504d

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

paddle/fluid/operators/sampling_id_op.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414
#pragma once
1515

16+
#include <algorithm>
17+
#include <iostream>
18+
#include <iterator>
1619
#include <random>
20+
#include <sstream>
1721
#include <vector>
1822
#include "paddle/fluid/framework/lod_tensor.h"
1923
#include "paddle/fluid/framework/op_registry.h"
@@ -34,17 +38,17 @@ class SamplingIdKernel : public framework::OpKernel<T> {
3438
std::vector<T> ins_vector;
3539
framework::TensorToVector(*input, context.device_context(), &ins_vector);
3640

37-
std::vector<int> ids(batch_size);
41+
std::vector<T> ids(batch_size);
3842
for (size_t i = 0; i < batch_size; ++i) {
3943
double r = this->get_rand();
40-
int id = width - 1;
44+
int idx = width - 1;
4145
for (int j = 0; j < width; ++j) {
4246
if ((r -= ins_vector[i * width + j]) < 0) {
43-
id = j;
47+
idx = j;
4448
break;
4549
}
4650
}
47-
ids[i] = id;
51+
ids[i] = ins_vector[i * width + idx];
4852
}
4953

5054
std::vector<int64_t> out_dim;

python/paddle/fluid/tests/unittests/test_sampling_id_op.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,18 @@ def setUp(self):
2525
self.op_type = "sampling_id"
2626
self.use_mkldnn = False
2727
self.init_kernel_type()
28-
X = np.random.random((3, 4)).astype('float32')
29-
self.inputs = {"X": X}
30-
Y = np.random.random(3).astype('float32')
31-
self.outputs = {'Out': Y}
28+
self.X = np.random.random((8, 4)).astype('float32')
29+
self.inputs = {"X": self.X}
30+
self.Y = np.random.random(8).astype('float32')
31+
self.outputs = {'Out': self.Y}
3232
self.attrs = {'use_mkldnn': self.use_mkldnn}
3333

3434
def test_check_output(self):
35-
self.check_output()
35+
self.check_output_customized(self.verify_output)
3636

37-
def test_check_grad(self):
38-
self.check_grad(['X'], 'Out')
37+
def verify_output(self, outs):
38+
out = np.array(outs[0])
39+
self.assertEqual(len(out), len(self.Y))
3940

4041
def init_kernel_type(self):
4142
pass

0 commit comments

Comments
 (0)