Skip to content

Commit f1c3ecb

Browse files
committed
add concat rows
1 parent 1509ce6 commit f1c3ecb

File tree

4 files changed

+136
-19
lines changed

4 files changed

+136
-19
lines changed

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
3434
auto ids_dims = ctx->GetInputDim("Ids");
3535

3636
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
37-
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
38-
// Maybe near future we will add concat_rows op.
37+
// lookup_table and concat_rows use the same InferShape, for lookup_table,
38+
// ids_var_type should be LoDTensor, for concat_rows, it should be
39+
// SelectedRows.
3940
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
4041
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
4142
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
@@ -90,6 +91,44 @@ or not. And the output only shares the LoD information with input Ids.
9091
}
9192
};
9293

94+
class ConcatRowsOpMaker : public framework::OpProtoAndCheckerMaker {
95+
public:
96+
ConcatRowsOpMaker(OpProto* proto, OpAttrChecker* op_checker)
97+
: OpProtoAndCheckerMaker(proto, op_checker) {
98+
AddInput("W",
99+
"(Tensor) The input tensor of concat_rows operator. "
100+
"The rank of this tensor is 2.");
101+
AddInput(
102+
"Ids",
103+
"(SelectedRows) The rows of Ids contains the index to be looked up "
104+
"in W.");
105+
AddOutput("Out",
106+
"(SelectedRows or Tensor) The result of concatenating, which "
107+
"have the same type as W.");
108+
AddAttr<bool>("is_sparse",
109+
"(boolean, default true) This attribution is invalid, it's "
110+
"only used by `Lookup Table Operator`.")
111+
.SetDefault(true);
112+
AddAttr<int64_t>("padding_idx",
113+
"(int64, default -1) "
114+
"If the value is -1, it makes no effect to lookup. "
115+
"Otherwise the given value indicates padding the output "
116+
"with zeros whenever lookup encounters it in Ids.")
117+
.SetDefault(-1);
118+
119+
AddComment(R"DOC(
120+
ConcatRows Operator.
121+
122+
This operator is used to perform lookups on the W(dense tensor) according to
123+
rows contained by Idx(sparse tensor), then concatenates them into a sparse
124+
tensor or dense tensor.
125+
126+
The type of Ids(Input) is SelectedRows.
127+
128+
)DOC");
129+
}
130+
};
131+
93132
class LookupTableOpGradDescMaker
94133
: public framework::DefaultGradOpDescMaker<true> {
95134
using ::paddle::framework::DefaultGradOpDescMaker<
@@ -150,3 +189,8 @@ REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
150189
ops::LookupTableKernel<double>);
151190
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
152191
ops::LookupTableGradKernel<double>);
192+
193+
// concat_rows is used by regularization and it doesn't have gradient operation.
194+
REGISTER_OPERATOR(concat_rows, ops::LookupTableOp, ops::ConcatRowsOpMaker);
195+
REGISTER_OP_CPU_KERNEL(concat_rows, ops::LookupTableKernel<float>,
196+
ops::LookupTableKernel<double>);

paddle/fluid/operators/lookup_table_op.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,20 +79,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
7979

8080
int64_t* ids;
8181
int64_t K;
82-
framework::Tensor* output_t;
82+
auto* output_t = context.Output<Tensor>("Out"); // float tensor;
8383

84-
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
85-
// Maybe near future we will add concat_rows op.
86-
if (ids_var->IsType<framework::LoDTensor>()) {
84+
// lookup_table and concat_rows use the same kernel, for lookup_table,
85+
// ids_var_type should be LoDTensor, for concat_rows, ids_var_type and
86+
// out_var_type should be SelectedRows.
87+
if (ids_var->IsType<LoDTensor>()) {
8788
auto* ids_t = context.Input<LoDTensor>("Ids");
88-
output_t = context.Output<LoDTensor>("Out"); // float tensor
8989
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
9090
K = ids_t->numel();
91-
} else if (ids_var->IsType<framework::SelectedRows>()) {
92-
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
93-
output_t = const_cast<framework::Tensor*>(
94-
&(context.Output<framework::SelectedRows>("Out")
95-
->value())); // float tensor
91+
} else if (ids_var->IsType<SelectedRows>()) {
92+
auto* ids_t = context.Input<SelectedRows>("Ids");
9693
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
9794
K = ids_t->rows().size();
9895
output_t->Resize({K, table_t->dims()[1]});
@@ -194,3 +191,6 @@ REGISTER_OP_CUDA_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
194191
REGISTER_OP_CUDA_KERNEL(lookup_table_grad,
195192
ops::LookupTableGradCUDAKernel<float>,
196193
ops::LookupTableGradCUDAKernel<double>);
194+
195+
REGISTER_OP_CUDA_KERNEL(concat_rows, ops::LookupTableCUDAKernel<float>,
196+
ops::LookupTableCUDAKernel<double>);

paddle/fluid/operators/lookup_table_op.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,16 @@ class LookupTableKernel : public framework::OpKernel<T> {
3535

3636
int64_t* ids;
3737
int64_t ids_numel;
38-
Tensor* output_t;
39-
40-
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
41-
// Maybe near future we will add concat_rows op.
38+
auto* output_t = context.Output<Tensor>("Out");
39+
// lookup_table and concat_rows use the same kernel, for lookup_table,
40+
// ids_var_type should be LoDTensor, for concat_rows, ids_var_type and
41+
// out_var_type should be SelectedRows.
4242
if (ids_var->IsType<LoDTensor>()) {
4343
auto* ids_t = context.Input<LoDTensor>("Ids");
44-
output_t = context.Output<LoDTensor>("Out");
4544
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
4645
ids_numel = ids_t->numel();
4746
} else if (ids_var->IsType<SelectedRows>()) {
4847
auto* ids_t = context.Input<SelectedRows>("Ids");
49-
output_t =
50-
const_cast<Tensor*>(&(context.Output<SelectedRows>("Out")->value()));
5148
ids = const_cast<int64_t*>(ids_t->rows().data());
5249
ids_numel = ids_t->rows().size();
5350
output_t->Resize({ids_numel, table_t->dims()[1]});
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
import paddle.fluid.core as core
18+
from paddle.fluid.op import Operator
19+
from op_test import OpTest
20+
21+
22+
class TestConcatRowsOp(OpTest):
23+
def check_with_place(self, place):
24+
scope = core.Scope()
25+
26+
# create and initialize Grad Variable
27+
height = 10
28+
rows = [0, 4, 4, 7]
29+
row_numel = 12
30+
31+
ids_selected_rows = scope.var('Ids').get_selected_rows()
32+
ids_selected_rows.set_height(height)
33+
ids_selected_rows.set_rows(rows)
34+
np_array = np.ones((len(rows), row_numel)).astype("float32")
35+
ids_tensor = ids_selected_rows.get_tensor()
36+
ids_tensor.set(np_array, place)
37+
38+
# create and initialize W Variable
39+
W = scope.var('W').get_tensor()
40+
W_array = np.full((height, row_numel), 1.0).astype("float32")
41+
for i in range(height):
42+
W_array[i] *= i
43+
W.set(W_array, place)
44+
45+
Out = scope.var('Out').get_selected_rows()
46+
Out_array = np.full((len(rows), row_numel), -1.0).astype("float32")
47+
Out.set_height(height)
48+
Out.set_rows(rows)
49+
Out_tensor = Out.get_tensor()
50+
Out_tensor.set(Out_array, place)
51+
52+
# create and run concat_rows_op operator
53+
concat_rows_op = Operator(
54+
"concat_rows",
55+
W='W',
56+
Ids='Ids',
57+
Out='Out',
58+
attrs={'is_sparse': True})
59+
concat_rows_op.run(scope, place)
60+
61+
# get and compare result
62+
result_array = np.array(Out_tensor)
63+
64+
for idx, row in enumerate(rows):
65+
assert (row == result_array[idx]).all()
66+
67+
def test_concat_rows(self):
68+
places = [core.CPUPlace()]
69+
if core.is_compiled_with_cuda():
70+
places.append(core.CUDAPlace(0))
71+
for place in places:
72+
self.check_with_place(place)
73+
74+
75+
if __name__ == "__main__":
76+
unittest.main()

0 commit comments

Comments
 (0)