Skip to content

Commit 08b73b6

Browse files
author
zenghsh3
committed
fix bug of sampling_id_op
1 parent 823c4f8 commit 08b73b6

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

paddle/fluid/operators/sampling_id_op.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@ class SamplingIdKernel : public framework::OpKernel<T> {
5353
static_cast<T>(context.Attr<float>("min")),
5454
static_cast<T>(context.Attr<float>("max")));
5555

56+
<<<<<<< HEAD
57+
std::vector<int64_t> ids(batch_size);
58+
for (size_t i = 0; i < batch_size; ++i) {
59+
=======
5660
std::vector<T> ids(batch_size);
5761
for (int i = 0; i < batch_size; ++i) {
62+
>>>>>>> 823c4f87beff04e4029e3f4a183658621ca8f01b
5863
T r = dist(engine);
5964
int idx = width - 1;
6065
for (int j = 0; j < width; ++j) {
@@ -63,7 +68,11 @@ class SamplingIdKernel : public framework::OpKernel<T> {
6368
break;
6469
}
6570
}
71+
<<<<<<< HEAD
72+
ids[i] = int64_t(idx);
73+
=======
6674
ids[i] = ins_vector[idx];
75+
>>>>>>> 823c4f87beff04e4029e3f4a183658621ca8f01b
6776
}
6877

6978
std::vector<int64_t> out_dim;

0 commit comments

Comments
 (0)