2 files changed, +11 -9 lines changed

@@ -21,15 +21,17 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
-    AddInput("X",
-             "(LoDTensor) the input tensor with shape{batch_num, N}, N is the "
-             "size of embedding table")
+    AddInput(
+        "X",
+        "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
+        "size of embedding table")
         .AsDuplicable();
     AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");

     AddComment(R"DOC(
 Merge multi LoDTensor's into one according to Ids's shard num.
-The values in the input LoDTensor are lookuped from the output of splite_ids_op
+The values in the input LoDTensor are lookuped from the output of split_ids_op
+
 Example:
   Input:
     Ids = [1,2,3,4,5,6]
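This change makes "X" a duplicable input, so the op now receives one LoDTensor per shard and merges their rows back into the original Ids order. Below is a minimal standalone sketch of that merge rule, not the Paddle kernel itself; it assumes ids are routed to shard id % shard_num (the round-robin rule split_ids_op is built around), and the tensor values and variable names are made up for illustration.

// Standalone sketch only -- not the Paddle kernel. Each X input holds the
// embedding rows its shard looked up, in the order the ids arrived; the
// merge writes them back in the original id order.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> ids = {1, 2, 3, 4, 5, 6};
  const size_t shard_num = 2;  // two duplicable X inputs, one per shard

  // Rows each shard produced (values invented for illustration).
  // Shard 0 handled ids 2, 4, 6; shard 1 handled ids 1, 3, 5.
  const std::vector<std::vector<float>> x_tensors[] = {
      {{0.2f, 0.2f}, {0.4f, 0.4f}, {0.6f, 0.6f}},   // X[0]
      {{0.1f, 0.1f}, {0.3f, 0.3f}, {0.5f, 0.5f}}};  // X[1]

  std::vector<size_t> in_index(shard_num, 0);  // next unread row per shard
  std::vector<std::vector<float>> out;
  for (const int64_t id : ids) {
    const size_t shard_id = static_cast<size_t>(id) % shard_num;
    out.push_back(x_tensors[shard_id][in_index[shard_id]++]);
  }

  for (size_t i = 0; i < out.size(); ++i) {
    std::cout << "id " << ids[i] << " -> {" << out[i][0] << ", " << out[i][1]
              << "}\n";
  }
  return 0;
}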
@@ -47,18 +47,18 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     int batch_size = 0;
     int embedding_size = 0;
     for (auto &input : x_tensors) {
-      if (embedding_size == 0) {
-        embedding_size = input->dims()[1];
-      }
       if (framework::product(input->dims()) != 0) {
+        if (embedding_size == 0) {
+          embedding_size = input->dims()[1];
+        }
         PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
                           "embedding size of all input should be the same");
         batch_size += input->dims()[0];
       }
     }
     PADDLE_ENFORCE_EQ(
         batch_size, ids_dims[0],
-        "the batch size of ids and embedding value should be the same");
+        "the batch size of ids and merged embedding value should be the same");

     const size_t shard_num = x_tensors.size();

@@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
       in_indexs[shard_id] += 1;
     }

-    for (int i = 0; i < shard_num; ++i) {
+    for (size_t i = 0; i < shard_num; ++i) {
       PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
                         "after merge, all data in x_tensor should be used");
     }
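With this change the kernel reads embedding_size only from inputs that actually contain data (product of dims != 0), so an empty shard output neither sets the expected width nor trips the size check, and the summed row count is still enforced against ids_dims[0]. Below is a minimal standalone sketch of that revised validation loop, not the Paddle kernel itself; x_dims and ids_rows are made-up stand-ins for the real tensor shapes.

// Standalone sketch only -- not the Paddle kernel. Mirrors the revised loop:
// the first non-empty input fixes the embedding width, empty shards are
// skipped entirely, and the merged row count must match the number of ids.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // dims of each X input as {rows, width}; the middle shard returned no rows.
  const std::vector<std::vector<int64_t>> x_dims = {{2, 8}, {0, 0}, {3, 8}};
  const int64_t ids_rows = 5;  // stand-in for ids_dims[0]

  int64_t batch_size = 0;
  int64_t embedding_size = 0;
  for (const auto &dims : x_dims) {
    const int64_t numel = std::accumulate(
        dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
    if (numel != 0) {             // only inputs that actually hold data
      if (embedding_size == 0) {  // first non-empty input fixes the width
        embedding_size = dims[1];
      }
      assert(embedding_size == dims[1] &&
             "embedding size of all input should be the same");
      batch_size += dims[0];
    }
  }
  assert(batch_size == ids_rows &&
         "the batch size of ids and merged embedding value should be the same");
  return 0;
}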