@@ -43,9 +43,9 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
43
43
PADDLE_ENFORCE_EQ (ids.size (), outs.size (),
44
44
" the number of Ids and Out should be the same" );
45
45
46
- size_t row_ids_size = 0 ;
47
- int row_size = 0 ;
48
- int embedding_size = 0 ;
46
+ int64_t row_ids_size = 0 ;
47
+ int64_t row_size = 0 ;
48
+ int64_t embedding_size = 0 ;
49
49
50
50
for (size_t i = 0 ; i < x_tensors.size (); ++i) {
51
51
const auto *x_tensor = x_tensors[i];
@@ -69,7 +69,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
69
69
for (size_t i = 0 ; i < x_tensors.size (); ++i) {
70
70
const auto *row_id = row_ids[i];
71
71
72
- for (int j = 0 ; j < row_id->numel (); ++j) {
72
+ for (auto j = 0 ; j < row_id->numel (); ++j) {
73
73
int64_t key = row_id->data <int64_t >()[j];
74
74
std::tuple<int64_t , int64_t > val = std::make_tuple (i, j);
75
75
selected_rows_idx_map.insert (std::make_pair (key, val));
@@ -84,13 +84,13 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
84
84
85
85
out->set_lod (out_ids->lod ());
86
86
87
- int nums = static_cast < int >( out_ids->dims ()[0 ]) ;
87
+ auto nums = out_ids->dims ()[0 ];
88
88
auto *out_data = out->mutable_data <T>(
89
89
framework::make_ddim ({nums, embedding_size}), place);
90
- for (int j = 0 ; j < nums; ++j) {
91
- int id = out_ids->data <int64_t >()[j];
92
- auto row_tuple = selected_rows_idx_map[id] ;
93
- int64_t row_idx = std::get<1 >(row_tuple);
90
+ for (auto j = 0 ; j < nums; ++j) {
91
+ auto id = out_ids->data <int64_t >()[j];
92
+ auto row_tuple = selected_rows_idx_map. at (id) ;
93
+ auto row_idx = std::get<1 >(row_tuple);
94
94
const auto *x_tensor = x_tensors[std::get<0 >(row_tuple)];
95
95
96
96
memcpy (out_data + embedding_size * j,
0 commit comments