Skip to content

Commit 7efdf05

Browse files
committed
make look_up_op supporting tensor ids
1 parent 56b50ee commit 7efdf05

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
3232

3333
auto table_dims = ctx->GetInputDim("W");
3434
auto ids_dims = ctx->GetInputDim("Ids");
35+
int ids_rank = ids_dims.size();
3536

36-
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
37-
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
37+
PADDLE_ENFORCE_EQ(table_dims.size(), 2);
38+
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
39+
"The last dimension of the 'Ids' tensor must be 1.");
3840

39-
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
41+
auto output_dims =
42+
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
43+
output_dims.push_back(table_dims[1]);
44+
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
4045

4146
if (ctx->GetOutputsVarType("Out")[0] ==
4247
framework::proto::VarType::LOD_TENSOR) {
@@ -61,8 +66,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
6166
AddInput("Ids",
6267
"An input with type int32 or int64 "
6368
"contains the ids to be looked up in W. "
64-
"Ids must be a column vector with rank = 2. "
65-
"The 2nd dimension size must be 1.");
69+
"The last dimension size must be 1.");
6670
AddOutput("Out", "The lookup results, which have the same type as W.");
6771
AddAttr<bool>("is_sparse",
6872
"(boolean, default false) "

paddle/fluid/operators/lookup_table_op.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,23 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
118118
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
119119

120120
auto *ids_data = ids->data<int64_t>();
121-
auto ids_dim = ids->dims();
121+
int64_t ids_num = ids->numel();
122122

123123
auto stream = dev_ctx.stream();
124124
// copy GPU memory to CPU pinned memory
125125
framework::Vector<int64_t> new_rows;
126-
new_rows.resize(ids_dim[0]);
126+
new_rows.resize(ids_num);
127127
auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
128128

129129
// TODO(yuyang18): Strange code here.
130130
memory::Copy(platform::CPUPlace(),
131131
new_rows.CUDAMutableData(context.GetPlace()), gpu_place,
132-
ids_data, ids_dim[0] * sizeof(int64_t), stream);
132+
ids_data, ids_num * sizeof(int64_t), stream);
133133

134134
d_table->set_rows(new_rows);
135135

136136
auto *d_table_value = d_table->mutable_value();
137-
d_table_value->Resize({ids_dim[0], table->dims()[1]});
137+
d_table_value->Resize({ids_num, table->dims()[1]});
138138
d_table_value->mutable_data<T>(context.GetPlace());
139139

140140
auto *d_table_data = d_table_value->data<T>();

paddle/fluid/operators/lookup_table_op.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,17 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
109109
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
110110

111111
auto *ids_data = ids->data<int64_t>();
112-
auto ids_dim = ids->dims();
112+
int64_t ids_num = ids->numel();
113113

114114
framework::Vector<int64_t> new_rows;
115-
new_rows.reserve(ids_dim[0]);
116-
for (int64_t i = 0; i < ids_dim[0]; i++) {
115+
new_rows.reserve(ids_num);
116+
for (int64_t i = 0; i < ids_num; i++) {
117117
new_rows.push_back(ids_data[i]);
118118
}
119119
d_table->set_rows(new_rows);
120120

121121
auto *d_table_value = d_table->mutable_value();
122-
d_table_value->Resize({ids_dim[0], table_dim[1]});
122+
d_table_value->Resize({ids_num, table_dim[1]});
123123
d_table_value->mutable_data<T>(context.GetPlace());
124124

125125
d_table->set_height(table_dim[0]);
@@ -135,7 +135,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
135135
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
136136

137137
auto *ids_data = ids->data<int64_t>();
138-
auto ids_dim = ids->dims();
139138

140139
int N = table_dim[0];
141140
int D = d_output->dims()[1];

0 commit comments

Comments
 (0)