Commit 509cb0b

add unit test, pass the unit test
1 parent 7cebec4 commit 509cb0b

3 files changed: 64 additions & 9 deletions

paddle/fluid/operators/merge_ids_op.cc

Lines changed: 10 additions & 2 deletions
@@ -73,6 +73,15 @@ class MergeIdsOp : public framework::OperatorWithKernel {
     }
     ctx->ShareLoD("Ids", "Out");
   }
+
+ private:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(
+            ctx.MultiInput<framework::Tensor>("X").front()->type()),
+        ctx.GetPlace());
+  }
 };
 
 class MergeIdsOpInferVarType : public framework::VarTypeInference {
@@ -93,5 +102,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
                   ops::MergeIdsOpInferVarType);
 REGISTER_OP_CPU_KERNEL(
-    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
-    ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
+    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
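
Note on the kernel registration change above: Ids always holds int64 row ids, while the embedding rows in X are float, so the expected kernel type is now derived from the first X input and only the float CPU kernel stays registered. A minimal numpy illustration of the dtypes involved (it mirrors the unit-test data added below, not the C++ API):

import numpy as np

# Illustrative only: the dtypes the merge_ids op sees after this commit.
ids = np.array([[0], [2], [2], [3], [5], [5], [6]], dtype='int64')    # read as int64_t in the kernel
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]], dtype='float32')  # drives the expected kernel type

print(ids.dtype, x0.dtype)  # int64 float32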

paddle/fluid/operators/merge_ids_op.h

Lines changed: 16 additions & 7 deletions
@@ -30,14 +30,15 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
     if (!platform::is_cpu_place(place)) {
       PADDLE_THROW("MergeIds do not support GPU kernel");
     }
+    VLOG(3) << "run in MergeIdsOpKernel";
 
     const auto *ids_var = ctx.InputVar("Ids");
     PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
                    "only support to merge Ids of LoDTensor");
 
     const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
     const auto &ids_dims = ids_tensor.dims();
-    const T *ids = ids_tensor.data<T>();
+    const int64_t *ids = ids_tensor.data<int64_t>();
 
     auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
 
@@ -49,9 +50,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
       if (embedding_size == 0) {
        embedding_size = input->dims()[1];
       }
-      PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
-                        "embedding size of all input should be the same");
-      batch_size += input->dims()[0];
+      if (framework::product(input->dims()) != 0) {
+        PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
+                          "embedding size of all input should be the same");
+        batch_size += input->dims()[0];
+      }
     }
     PADDLE_ENFORCE_EQ(
         batch_size, ids_dims[0],
@@ -61,20 +64,26 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
 
     if (shard_num == 1) {
       VLOG(3) << "only one shard, we can copy the data directly";
-      TensorCopy(ids_tensor, place, out);
+      TensorCopy(*x_tensors[0], place, out);
     } else {
       std::vector<int> in_indexs(shard_num, 0);
-      auto *out_data = out->mutable_data<T>(ids_dims, place);
+      auto *out_data = out->mutable_data<T>(
+          framework::make_ddim({batch_size, embedding_size}), place);
       // copy data from ins[shard_num] to out.
       for (int i = 0; i < ids_dims[0]; ++i) {
-        T id = ids[i];
+        int64_t id = ids[i];
         size_t shard_id = static_cast<size_t>(id) % shard_num;
         int index = in_indexs[shard_id];
         memcpy(out_data + embedding_size * i,
                x_tensors[shard_id]->data<T>() + index * embedding_size,
                sizeof(T) * embedding_size);
         in_indexs[shard_id] += 1;
       }
+
+      for (int i = 0; i < shard_num; ++i) {
+        PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
+                          "after merge, all data in x_tensor should be used");
+      }
     }
   }
 };
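
The kernel above merges the sharded embedding rows back into Ids order: output row i comes from shard ids[i] % shard_num, each shard's rows are consumed sequentially, and the new check enforces that every input row is used exactly once. A small numpy reference of the same logic, a sketch only (the helper name merge_ids_reference is made up here and is not part of the operator's API):

import numpy as np


def merge_ids_reference(ids, shards):
    # Output row i comes from shard (ids[i] % len(shards)),
    # consuming that shard's rows in order (like in_indexs in the kernel).
    num_shards = len(shards)
    embedding_size = next(s.shape[1] for s in shards if s.size != 0)
    out = np.zeros((ids.shape[0], embedding_size), dtype='float32')
    next_row = [0] * num_shards
    for i, id_value in enumerate(ids.flatten()):
        shard_id = int(id_value) % num_shards
        out[i] = shards[shard_id][next_row[shard_id]]
        next_row[shard_id] += 1
    # after merging, every row of every shard must have been consumed
    for shard, used in zip(shards, next_row):
        assert used == shard.shape[0]
    return out
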
python/paddle/fluid/tests/unittests/test_merge_ids_op.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestMergeIdsOp(OpTest):
+    def setUp(self):
+        self.op_type = "merge_ids"
+        ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
+        x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
+        x1 = np.array([]).astype('float32')
+        x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6],
+                       [0.5, 0.6]]).astype('float32')
+        out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
+                        [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
+        self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]}
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == '__main__':
+    unittest.main()
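
How the expected Out above follows from the inputs: with three shards, ids 0, 3 and 6 map to x0 (id % 3 == 0), ids 2 and 5 map to x2 (id % 3 == 2), and no id maps to shard 1, which is why x1 is empty. A quick check against the reference sketch from the previous note (it assumes merge_ids_reference as defined there):

import numpy as np

ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], [0.5, 0.6]]).astype('float32')
expected = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
                     [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')

# Rebuilding Out row by row reproduces the expected tensor exactly.
assert np.allclose(merge_ids_reference(ids, [x0, x1, x2]), expected)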
