Skip to content

Commit 1509ce6

Browse files
committed
enhancement look_up_table
1 parent 0d49b92 commit 1509ce6

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,13 @@ class LookupTableOp : public framework::OperatorWithKernel {
3333
auto table_dims = ctx->GetInputDim("W");
3434
auto ids_dims = ctx->GetInputDim("Ids");
3535

36-
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
37-
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
36+
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
37+
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
38+
// Maybe near future we will add concat_rows op.
39+
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
40+
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
41+
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
42+
}
3843

3944
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
4045
ctx->ShareLoD("Ids", /*->*/ "Out");

paddle/fluid/operators/lookup_table_op.cu

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,34 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
7474
public:
7575
void Compute(const framework::ExecutionContext& context) const override {
7676
auto* table_t = context.Input<LoDTensor>("W");
77-
auto* ids_t = context.Input<LoDTensor>("Ids");
78-
auto* output_t = context.Output<LoDTensor>("Out");
7977
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
78+
auto* ids_var = context.InputVar("Ids"); // int tensor
79+
80+
int64_t* ids;
81+
int64_t K;
82+
framework::Tensor* output_t;
83+
84+
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
85+
// Maybe near future we will add concat_rows op.
86+
if (ids_var->IsType<framework::LoDTensor>()) {
87+
auto* ids_t = context.Input<LoDTensor>("Ids");
88+
output_t = context.Output<LoDTensor>("Out"); // float tensor
89+
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
90+
K = ids_t->numel();
91+
} else if (ids_var->IsType<framework::SelectedRows>()) {
92+
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
93+
output_t = const_cast<framework::Tensor*>(
94+
&(context.Output<framework::SelectedRows>("Out")
95+
->value())); // float tensor
96+
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
97+
K = ids_t->rows().size();
98+
output_t->Resize({K, table_t->dims()[1]});
99+
} else {
100+
PADDLE_THROW("Unsupported Variable Type of Ids");
101+
}
80102

81103
size_t N = table_t->dims()[0];
82104
size_t D = table_t->dims()[1];
83-
size_t K = ids_t->numel();
84-
auto* ids = ids_t->data<int64_t>();
85105
auto* table = table_t->data<T>();
86106
auto* output = output_t->mutable_data<T>(context.GetPlace());
87107

paddle/fluid/operators/lookup_table_op.h

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,54 @@ limitations under the License. */
2222
namespace paddle {
2323
namespace operators {
2424

25+
using Tensor = framework::Tensor;
2526
using LoDTensor = framework::LoDTensor;
2627
using SelectedRows = framework::SelectedRows;
2728

2829
template <typename T>
2930
class LookupTableKernel : public framework::OpKernel<T> {
3031
public:
3132
void Compute(const framework::ExecutionContext& context) const override {
32-
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
33-
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
34-
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
33+
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
34+
auto* ids_var = context.InputVar("Ids"); // int tensor
35+
36+
int64_t* ids;
37+
int64_t ids_numel;
38+
Tensor* output_t;
39+
40+
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
41+
// Maybe near future we will add concat_rows op.
42+
if (ids_var->IsType<LoDTensor>()) {
43+
auto* ids_t = context.Input<LoDTensor>("Ids");
44+
output_t = context.Output<LoDTensor>("Out");
45+
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
46+
ids_numel = ids_t->numel();
47+
} else if (ids_var->IsType<SelectedRows>()) {
48+
auto* ids_t = context.Input<SelectedRows>("Ids");
49+
output_t =
50+
const_cast<Tensor*>(&(context.Output<SelectedRows>("Out")->value()));
51+
ids = const_cast<int64_t*>(ids_t->rows().data());
52+
ids_numel = ids_t->rows().size();
53+
output_t->Resize({ids_numel, table_t->dims()[1]});
54+
} else {
55+
PADDLE_THROW("Unsupported Variable Type of Ids");
56+
}
57+
3558
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
3659

3760
int N = table_t->dims()[0];
3861
int D = table_t->dims()[1];
39-
auto* ids = ids_t->data<int64_t>();
4062
auto* table = table_t->data<T>();
4163
auto* output = output_t->mutable_data<T>(context.GetPlace());
4264

4365
if (padding_idx == -1) {
44-
for (int64_t i = 0; i < ids_t->numel(); ++i) {
66+
for (int64_t i = 0; i < ids_numel; ++i) {
4567
PADDLE_ENFORCE_LT(ids[i], N);
4668
PADDLE_ENFORCE_GE(ids[i], 0);
4769
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
4870
}
4971
} else {
50-
for (int64_t i = 0; i < ids_t->numel(); ++i) {
72+
for (int64_t i = 0; i < ids_numel; ++i) {
5173
if (ids[i] == padding_idx) {
5274
memset(output + i * D, 0, D * sizeof(T));
5375
} else {

0 commit comments

Comments
 (0)