@@ -30,14 +30,15 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
30
30
if (!platform::is_cpu_place (place)) {
31
31
PADDLE_THROW (" MergeIds do not support GPU kernel" );
32
32
}
33
+ VLOG (3 ) << " run in MergeIdsOpKernel" ;
33
34
34
35
const auto *ids_var = ctx.InputVar (" Ids" );
35
36
PADDLE_ENFORCE (ids_var->IsType <framework::LoDTensor>(),
36
37
" only support to merge Ids of LoDTensor" );
37
38
38
39
const auto &ids_tensor = ids_var->Get <framework::LoDTensor>();
39
40
const auto &ids_dims = ids_tensor.dims ();
40
- const T *ids = ids_tensor.data <T >();
41
+ const int64_t *ids = ids_tensor.data <int64_t >();
41
42
42
43
auto x_tensors = ctx.MultiInput <framework::LoDTensor>(" X" );
43
44
@@ -49,9 +50,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
49
50
if (embedding_size == 0 ) {
50
51
embedding_size = input->dims ()[1 ];
51
52
}
52
- PADDLE_ENFORCE_EQ (embedding_size, input->dims ()[1 ],
53
- " embedding size of all input should be the same" );
54
- batch_size += input->dims ()[0 ];
53
+ if (framework::product (input->dims ()) != 0 ) {
54
+ PADDLE_ENFORCE_EQ (embedding_size, input->dims ()[1 ],
55
+ " embedding size of all input should be the same" );
56
+ batch_size += input->dims ()[0 ];
57
+ }
55
58
}
56
59
PADDLE_ENFORCE_EQ (
57
60
batch_size, ids_dims[0 ],
@@ -61,20 +64,26 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
61
64
62
65
if (shard_num == 1 ) {
63
66
VLOG (3 ) << " only one shard, we can copy the data directly" ;
64
- TensorCopy (ids_tensor , place, out);
67
+ TensorCopy (*x_tensors[ 0 ] , place, out);
65
68
} else {
66
69
std::vector<int > in_indexs (shard_num, 0 );
67
- auto *out_data = out->mutable_data <T>(ids_dims, place);
70
+ auto *out_data = out->mutable_data <T>(
71
+ framework::make_ddim ({batch_size, embedding_size}), place);
68
72
// copy data from ins[shard_num] to out.
69
73
for (int i = 0 ; i < ids_dims[0 ]; ++i) {
70
- T id = ids[i];
74
+ int64_t id = ids[i];
71
75
size_t shard_id = static_cast <size_t >(id) % shard_num;
72
76
int index = in_indexs[shard_id];
73
77
memcpy (out_data + embedding_size * i,
74
78
x_tensors[shard_id]->data <T>() + index * embedding_size,
75
79
sizeof (T) * embedding_size);
76
80
in_indexs[shard_id] += 1 ;
77
81
}
82
+
83
+ for (int i = 0 ; i < shard_num; ++i) {
84
+ PADDLE_ENFORCE_EQ (in_indexs[i], x_tensors[i]->dims ()[0 ],
85
+ " after merge, all data in x_tensor should be used" );
86
+ }
78
87
}
79
88
}
80
89
};
0 commit comments