Skip to content

Commit 9074a60

Browse files
authored
Refine lookup_table_op (#5257)
1. Change some `auto` to `auto*` 2. Change `Tensor` to `LoDTensor`
1 parent db3b943 commit 9074a60

File tree

3 files changed

+28
-28
lines changed

3 files changed

+28
-28
lines changed

paddle/operators/lookup_table_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
4343
protected:
4444
framework::DataType IndicateDataType(
4545
const framework::ExecutionContext& ctx) const override {
46-
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
46+
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
4747
}
4848
};
4949

@@ -93,7 +93,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
9393
protected:
9494
framework::DataType IndicateDataType(
9595
const framework::ExecutionContext& ctx) const override {
96-
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
96+
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
9797
}
9898
};
9999

paddle/operators/lookup_table_op.cu

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ template <typename T>
6161
class LookupTableCUDAKernel : public framework::OpKernel<T> {
6262
public:
6363
void Compute(const framework::ExecutionContext& context) const override {
64-
auto table_t = context.Input<Tensor>("W");
65-
auto ids_t = context.Input<Tensor>("Ids");
66-
auto output_t = context.Output<Tensor>("Out");
64+
auto* table_t = context.Input<LoDTensor>("W");
65+
auto* ids_t = context.Input<LoDTensor>("Ids");
66+
auto* output_t = context.Output<LoDTensor>("Out");
6767

6868
size_t N = table_t->dims()[0];
6969
size_t D = table_t->dims()[1];
7070
size_t K = ids_t->numel();
71-
auto ids = ids_t->data<int64_t>();
72-
auto table = table_t->data<T>();
73-
auto output = output_t->mutable_data<T>(context.GetPlace());
71+
auto* ids = ids_t->data<int64_t>();
72+
auto* table = table_t->data<T>();
73+
auto* output = output_t->mutable_data<T>(context.GetPlace());
7474

7575
dim3 threads(128, 8);
7676
dim3 grids(8, 1);
@@ -87,9 +87,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
8787
void Compute(const framework::ExecutionContext& context) const override {
8888
bool is_sparse = context.Attr<bool>("is_sparse");
8989
if (is_sparse) {
90-
auto* ids = context.Input<Tensor>("Ids");
91-
auto* table = context.Input<Tensor>("W");
92-
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
90+
auto* ids = context.Input<LoDTensor>("Ids");
91+
auto* table = context.Input<LoDTensor>("W");
92+
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
9393
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
9494

9595
auto* ids_data = ids->data<int64_t>();
@@ -119,9 +119,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
119119
d_output->numel(), stream);
120120

121121
} else {
122-
auto ids_t = context.Input<Tensor>("Ids");
123-
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
124-
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
122+
auto ids_t = context.Input<LoDTensor>("Ids");
123+
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
124+
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
125125

126126
int N = d_table_t->dims()[0];
127127
int D = d_table_t->dims()[1];

paddle/operators/lookup_table_op.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,22 @@
1919
namespace paddle {
2020
namespace operators {
2121

22-
using Tensor = framework::Tensor;
22+
using LoDTensor = framework::LoDTensor;
2323
using SelectedRows = framework::SelectedRows;
2424

2525
template <typename T>
2626
class LookupTableKernel : public framework::OpKernel<T> {
2727
public:
2828
void Compute(const framework::ExecutionContext& context) const override {
29-
auto table_t = context.Input<Tensor>("W"); // float tensor
30-
auto ids_t = context.Input<Tensor>("Ids"); // int tensor
31-
auto output_t = context.Output<Tensor>("Out"); // float tensor
29+
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
30+
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
31+
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
3232

3333
int N = table_t->dims()[0];
3434
int D = table_t->dims()[1];
35-
auto ids = ids_t->data<int64_t>();
36-
auto table = table_t->data<T>();
37-
auto output = output_t->mutable_data<T>(context.GetPlace());
35+
auto* ids = ids_t->data<int64_t>();
36+
auto* table = table_t->data<T>();
37+
auto* output = output_t->mutable_data<T>(context.GetPlace());
3838
for (int64_t i = 0; i < ids_t->numel(); ++i) {
3939
PADDLE_ENFORCE_LT(ids[i], N);
4040
PADDLE_ENFORCE_GE(ids[i], 0);
@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
4949
void Compute(const framework::ExecutionContext& context) const override {
5050
bool is_sparse = context.Attr<bool>("is_sparse");
5151
if (is_sparse) {
52-
auto* ids = context.Input<Tensor>("Ids");
53-
auto* table = context.Input<Tensor>("W");
54-
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
52+
auto* ids = context.Input<LoDTensor>("Ids");
53+
auto* table = context.Input<LoDTensor>("W");
54+
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
5555
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
5656

5757
auto* ids_data = ids->data<int64_t>();
@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
7676
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
7777
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
7878
} else {
79-
auto* ids = context.Input<Tensor>("Ids");
80-
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
81-
auto* d_table = context.Output<Tensor>(framework::GradVarName("W"));
82-
auto* table = context.Input<Tensor>("W");
79+
auto* ids = context.Input<LoDTensor>("Ids");
80+
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
81+
auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
82+
auto* table = context.Input<LoDTensor>("W");
8383

8484
auto* ids_data = ids->data<int64_t>();
8585
auto ids_dim = ids->dims();

0 commit comments

Comments
 (0)