Skip to content

Commit d6c8d26

Browse files
committed
optimize code and comment
1 parent f031555 commit d6c8d26

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

paddle/fluid/operators/merge_ids_op.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
2121
public:
2222
void Make() override {
2323
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
24-
AddInput("X",
25-
"(LoDTensor) the input tensor with shape{batch_num, N}, N is the "
26-
"size of embedding table")
24+
AddInput(
25+
"X",
26+
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
27+
"size of embedding table")
2728
.AsDuplicable();
2829
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
2930

3031
AddComment(R"DOC(
3132
Merge multi LoDTensor's into one according to Ids's shard num.
32-
The values in the input LoDTensor are lookuped from the output of splite_ids_op
33+
The values in the input LoDTensor are lookuped from the output of split_ids_op
34+
3335
Example:
3436
Input:
3537
Ids = [1,2,3,4,5,6]

paddle/fluid/operators/merge_ids_op.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,18 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
4747
int batch_size = 0;
4848
int embedding_size = 0;
4949
for (auto &input : x_tensors) {
50-
if (embedding_size == 0) {
51-
embedding_size = input->dims()[1];
52-
}
5350
if (framework::product(input->dims()) != 0) {
51+
if (embedding_size == 0) {
52+
embedding_size = input->dims()[1];
53+
}
5454
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
5555
"embedding size of all input should be the same");
5656
batch_size += input->dims()[0];
5757
}
5858
}
5959
PADDLE_ENFORCE_EQ(
6060
batch_size, ids_dims[0],
61-
"the batch size of ids and embedding value should be the same");
61+
"the batch size of ids and merged embedding value should be the same");
6262

6363
const size_t shard_num = x_tensors.size();
6464

@@ -80,7 +80,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
8080
in_indexs[shard_id] += 1;
8181
}
8282

83-
for (int i = 0; i < shard_num; ++i) {
83+
for (size_t i = 0; i < shard_num; ++i) {
8484
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
8585
"after merge, all data in x_tensor should be used");
8686
}

0 commit comments

Comments (0)