Skip to content

Commit 981fc2b

Browse files
authored
fix bug in merge_ids (#15503)
* fix mistakes in merge_ids, test=develop
1 parent a7ba07d commit 981fc2b

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

paddle/fluid/operators/distributed_ops/merge_ids_op.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
4343
PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
4444
"the number of Ids and Out should be the same");
4545

46-
size_t row_ids_size = 0;
47-
int row_size = 0;
48-
int embedding_size = 0;
46+
int64_t row_ids_size = 0;
47+
int64_t row_size = 0;
48+
int64_t embedding_size = 0;
4949

5050
for (size_t i = 0; i < x_tensors.size(); ++i) {
5151
const auto *x_tensor = x_tensors[i];
@@ -69,7 +69,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
6969
for (size_t i = 0; i < x_tensors.size(); ++i) {
7070
const auto *row_id = row_ids[i];
7171

72-
for (int j = 0; j < row_id->numel(); ++j) {
72+
for (auto j = 0; j < row_id->numel(); ++j) {
7373
int64_t key = row_id->data<int64_t>()[j];
7474
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
7575
selected_rows_idx_map.insert(std::make_pair(key, val));
@@ -84,13 +84,13 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
8484

8585
out->set_lod(out_ids->lod());
8686

87-
int nums = static_cast<int>(out_ids->dims()[0]);
87+
auto nums = out_ids->dims()[0];
8888
auto *out_data = out->mutable_data<T>(
8989
framework::make_ddim({nums, embedding_size}), place);
90-
for (int j = 0; j < nums; ++j) {
91-
int id = out_ids->data<int64_t>()[j];
92-
auto row_tuple = selected_rows_idx_map[id];
93-
int64_t row_idx = std::get<1>(row_tuple);
90+
for (auto j = 0; j < nums; ++j) {
91+
auto id = out_ids->data<int64_t>()[j];
92+
auto row_tuple = selected_rows_idx_map.at(id);
93+
auto row_idx = std::get<1>(row_tuple);
9494
const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
9595

9696
memcpy(out_data + embedding_size * j,

0 commit comments

Comments
 (0)