Skip to content

Commit 788c600

Browse files
author
chengduo
authored
Merge pull request #8932 from chengduoZH/feature/add_concat_rows
Enhance lookup_table op
2 parents 2807896 + a43eee4 commit 788c600

File tree

4 files changed

+129
-23
lines changed

4 files changed

+129
-23
lines changed

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
3333
auto table_dims = ctx->GetInputDim("W");
3434
auto ids_dims = ctx->GetInputDim("Ids");
3535

36-
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
37-
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
36+
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
37+
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
38+
// is LoDTensor, this tensor contains the ids to be looked up in W
39+
// and it must be a column vector with rank = 2 while the 2nd dimension
40+
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
41+
// contains the ids to be looked up in W;
42+
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
43+
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
44+
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
45+
}
3846

3947
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
4048
ctx->ShareLoD("Ids", /*->*/ "Out");
@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
5462
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
5563
: OpProtoAndCheckerMaker(proto, op_checker) {
5664
AddInput("W",
57-
"An input represents embedding tensors, "
65+
"(Tensor) The input represents embedding tensors, "
5866
"which is a learnable parameter.");
59-
AddInput("Ids",
60-
"An input with type int32 or int64 "
61-
"contains the ids to be looked up in W. "
62-
"Ids must be a column vector with rank = 2. "
63-
"The 2nd dimension size must be 1.");
64-
AddOutput("Out", "The lookup results, which have the same type as W.");
67+
AddInput(
68+
"Ids",
69+
"(Tensor or SelectedRows) Ids's type can be Tensor or "
70+
"SelectedRows, when Ids's type is Tensor, this tensor contains "
71+
"the ids to be looked up in W and it must be a column vector with "
72+
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
73+
"SelectedRows, the rows of Ids contains the ids to be looked up "
74+
"in W.");
75+
AddOutput("Out",
76+
"(Tensor or SelectedRows) The lookup results, which have the "
77+
"same type as W.");
6578
AddAttr<bool>("is_sparse",
6679
"(boolean, default false) "
67-
"Sparse update")
80+
"Sparse update.")
6881
.SetDefault(false);
6982
AddAttr<int64_t>("padding_idx",
7083
"(int64, default -1) "
@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
7689
Lookup Table Operator.
7790
7891
This operator is used to perform lookups on the parameter W,
79-
then concatenated into a dense tensor.
92+
then concatenated into a dense or sparse tensor.
93+
94+
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
95+
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
96+
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
97+
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
98+
at this time, Ids can carry the LoD (Level of Details) information, or not, and
99+
the output only shares the LoD information with input Ids.
80100
81-
The input Ids can carry the LoD (Level of Details) information,
82-
or not. And the output only shares the LoD information with input Ids.
83101
84102
)DOC");
85103
}

paddle/fluid/operators/lookup_table_op.cu

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
7474
public:
7575
void Compute(const framework::ExecutionContext& context) const override {
7676
auto* table_t = context.Input<LoDTensor>("W");
77-
auto* ids_t = context.Input<LoDTensor>("Ids");
78-
auto* output_t = context.Output<LoDTensor>("Out");
7977
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
78+
auto* ids_var = context.InputVar("Ids");
79+
Tensor* output_t = context.Output<Tensor>("Out");
80+
81+
int64_t* ids;
82+
int64_t K;
83+
84+
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
85+
// is LoDTensor, this tensor contains the ids to be looked up in W;
86+
// when Ids's type is SelectedRows, the rows of Ids contains the
87+
// ids to be looked up in W.
88+
if (ids_var->IsType<framework::LoDTensor>()) {
89+
auto* ids_t = context.Input<LoDTensor>("Ids");
90+
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
91+
K = ids_t->numel();
92+
} else if (ids_var->IsType<framework::SelectedRows>()) {
93+
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
94+
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
95+
K = ids_t->rows().size();
96+
output_t->Resize({K, table_t->dims()[1]});
97+
} else {
98+
PADDLE_THROW("Unsupported Variable Type of Ids");
99+
}
80100

81101
size_t N = table_t->dims()[0];
82102
size_t D = table_t->dims()[1];
83-
size_t K = ids_t->numel();
84-
auto* ids = ids_t->data<int64_t>();
85103
auto* table = table_t->data<T>();
86104
auto* output = output_t->mutable_data<T>(context.GetPlace());
87105

paddle/fluid/operators/lookup_table_op.h

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,53 @@ limitations under the License. */
2222
namespace paddle {
2323
namespace operators {
2424

25+
using Tensor = framework::Tensor;
2526
using LoDTensor = framework::LoDTensor;
2627
using SelectedRows = framework::SelectedRows;
2728

2829
template <typename T>
2930
class LookupTableKernel : public framework::OpKernel<T> {
3031
public:
3132
void Compute(const framework::ExecutionContext& context) const override {
32-
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
33-
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
34-
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
33+
auto* table_t = context.Input<LoDTensor>("W");
34+
auto* ids_var = context.InputVar("Ids");
35+
Tensor* output_t = context.Output<Tensor>("Out");
36+
37+
int64_t* ids;
38+
int64_t ids_numel;
39+
40+
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
41+
// is LoDTensor, this tensor contains the ids to be looked up in W;
42+
// when Ids's type is SelectedRows, the rows of Ids contains the
43+
// ids to be looked up in W.
44+
if (ids_var->IsType<LoDTensor>()) {
45+
auto* ids_t = context.Input<LoDTensor>("Ids");
46+
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
47+
ids_numel = ids_t->numel();
48+
} else if (ids_var->IsType<SelectedRows>()) {
49+
auto* ids_t = context.Input<SelectedRows>("Ids");
50+
ids = const_cast<int64_t*>(ids_t->rows().data());
51+
ids_numel = ids_t->rows().size();
52+
output_t->Resize({ids_numel, table_t->dims()[1]});
53+
} else {
54+
PADDLE_THROW("Unsupported Variable Type of Ids");
55+
}
56+
3557
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
3658

3759
int N = table_t->dims()[0];
3860
int D = table_t->dims()[1];
39-
auto* ids = ids_t->data<int64_t>();
4061
auto* table = table_t->data<T>();
4162
auto* output = output_t->mutable_data<T>(context.GetPlace());
4263

4364
if (padding_idx == -1) {
44-
for (int64_t i = 0; i < ids_t->numel(); ++i) {
65+
for (int64_t i = 0; i < ids_numel; ++i) {
4566
PADDLE_ENFORCE_LT(ids[i], N);
4667
PADDLE_ENFORCE_GE(ids[i], 0);
4768
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
4869
}
4970
} else {
50-
for (int64_t i = 0; i < ids_t->numel(); ++i) {
71+
for (int64_t i = 0; i < ids_numel; ++i) {
5172
if (ids[i] == padding_idx) {
5273
memset(output + i * D, 0, D * sizeof(T));
5374
} else {

python/paddle/fluid/tests/unittests/test_lookup_table_op.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import unittest
1616
import numpy as np
1717
from op_test import OpTest
18+
import paddle.fluid.core as core
19+
from paddle.fluid.op import Operator
1820

1921

2022
class TestLookupTableOp(OpTest):
@@ -47,5 +49,52 @@ def test_check_grad(self):
4749
pass
4850

4951

52+
class TestLookupTableIdsIsSelectedRows(OpTest):
53+
def check_with_place(self, place):
54+
scope = core.Scope()
55+
56+
# create and initialize Variable
57+
height = 10
58+
rows = [0, 4, 4, 7]
59+
row_numel = 12
60+
61+
# create and initialize W Variable
62+
W = scope.var('W').get_tensor()
63+
W_array = np.full((height, row_numel), 1.0).astype("float32")
64+
for i in range(height):
65+
W_array[i] *= i
66+
W.set(W_array, place)
67+
68+
# create and initialize Ids Variable
69+
ids_selected_rows = scope.var('Ids').get_selected_rows()
70+
ids_selected_rows.set_height(len(rows))
71+
ids_selected_rows.set_rows(rows)
72+
np_array = np.ones((len(rows), row_numel)).astype("float32")
73+
ids_tensor = ids_selected_rows.get_tensor()
74+
ids_tensor.set(np_array, place)
75+
76+
# create Out Variable
77+
Out = scope.var('Out').get_selected_rows()
78+
79+
# create and run lookup_table operator
80+
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
81+
concat_rows_op.run(scope, place)
82+
83+
# get result from Out
84+
Out_tensor = Out.get_tensor()
85+
result_array = np.array(Out_tensor)
86+
87+
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
88+
for idx, row in enumerate(rows):
89+
assert (row == result_array[idx]).all()
90+
91+
def test_concat_rows(self):
92+
places = [core.CPUPlace()]
93+
if core.is_compiled_with_cuda():
94+
places.append(core.CUDAPlace(0))
95+
for place in places:
96+
self.check_with_place(place)
97+
98+
5099
if __name__ == "__main__":
51100
unittest.main()

0 commit comments

Comments (0)