Skip to content

Commit 37d9a72

Browse files
authored
Merge pull request #9575 from jacquesqiao/lookup_table_support_SelectedRows_as_parameter
Lookup table: support SelectedRows as the parameter type
2 parents 172c887 + 13ecb5e commit 37d9a72

File tree

4 files changed

+152
-51
lines changed

4 files changed

+152
-51
lines changed

paddle/fluid/framework/selected_rows.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

1212
#pragma once
13+
14+
#include <vector>
15+
1316
#include "paddle/fluid/framework/lod_tensor.h"
1417
#include "paddle/fluid/framework/tensor.h"
1518

@@ -52,7 +55,7 @@ class SelectedRows {
5255

5356
private:
5457
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
55-
// SelectedRows are simplely concated when adding together. Until a
58+
// SelectedRows are simply concatenated when adding together. Until a
5659
// SelectedRows add a Tensor, will the duplicate rows be handled.
5760
Vector<int64_t> rows_;
5861
std::unique_ptr<Tensor> value_{nullptr};

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21+
static inline framework::OpKernelType ExpectedKernelType(
22+
const framework::ExecutionContext& ctx) {
23+
auto* table_var = ctx.InputVar("W");
24+
if (table_var->IsType<LoDTensor>()) {
25+
return framework::OpKernelType(
26+
framework::ToDataType(table_var->Get<LoDTensor>().type()),
27+
ctx.device_context());
28+
} else if (table_var->IsType<SelectedRows>()) {
29+
return framework::OpKernelType(
30+
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
31+
ctx.device_context());
32+
} else {
33+
PADDLE_THROW("W should be LoDTensor or SelectedRows");
34+
}
35+
}
36+
2137
class LookupTableOp : public framework::OperatorWithKernel {
2238
public:
2339
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -51,9 +67,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
5167
protected:
5268
framework::OpKernelType GetExpectedKernelType(
5369
const framework::ExecutionContext& ctx) const override {
54-
return framework::OpKernelType(
55-
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
56-
ctx.device_context());
70+
return ExpectedKernelType(ctx);
5771
}
5872
};
5973

@@ -84,7 +98,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
8498
"If the value is -1, it makes no effect to lookup. "
8599
"Otherwise the given value indicates padding the output "
86100
"with zeros whenever lookup encounters it in Ids.")
87-
.SetDefault(-1);
101+
.SetDefault(kNoPadding);
88102
AddComment(R"DOC(
89103
Lookup Table Operator.
90104
@@ -124,9 +138,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
124138
protected:
125139
framework::OpKernelType GetExpectedKernelType(
126140
const framework::ExecutionContext& ctx) const override {
127-
return framework::OpKernelType(
128-
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
129-
ctx.device_context());
141+
return ExpectedKernelType(ctx);
130142
}
131143
};
132144

paddle/fluid/operators/lookup_table_op.h

Lines changed: 87 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <string>
18+
#include <vector>
19+
1720
#include "paddle/fluid/framework/eigen.h"
1821
#include "paddle/fluid/framework/lod_tensor.h"
1922
#include "paddle/fluid/framework/op_registry.h"
@@ -25,56 +28,88 @@ namespace operators {
2528
using Tensor = framework::Tensor;
2629
using LoDTensor = framework::LoDTensor;
2730
using SelectedRows = framework::SelectedRows;
31+
using DDim = framework::DDim;
32+
33+
static constexpr int64_t kNoPadding = -1;
34+
35+
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
36+
auto it = std::find(rows.begin(), rows.end(), value);
37+
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
38+
return static_cast<size_t>(std::distance(rows.begin(), it));
39+
}
2840

2941
template <typename T>
3042
class LookupTableKernel : public framework::OpKernel<T> {
3143
public:
32-
void Compute(const framework::ExecutionContext& context) const override {
33-
auto* table_t = context.Input<LoDTensor>("W");
34-
auto* ids_var = context.InputVar("Ids");
35-
Tensor* output_t = context.Output<Tensor>("Out");
44+
void Compute(const framework::ExecutionContext &context) const override {
45+
auto *table_var = context.InputVar("W");
46+
auto *ids_var = context.InputVar("Ids");
47+
Tensor *output_t = context.Output<Tensor>("Out");
48+
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
49+
50+
DDim table_dim;
3651

37-
int64_t* ids;
52+
if (table_var->IsType<LoDTensor>()) {
53+
table_dim = context.Input<LoDTensor>("W")->dims();
54+
} else if (table_var->IsType<SelectedRows>()) {
55+
auto *table_t = context.Input<SelectedRows>("W");
56+
table_dim = table_t->value().dims();
57+
} else {
58+
PADDLE_THROW("table only support LoDTensor and SelectedRows");
59+
}
60+
61+
int64_t *ids;
3862
int64_t ids_numel;
3963

4064
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
4165
// is LoDTensor, this tensor contains the ids to be looked up in W;
4266
// when Ids's type is SelectedRows, the rows of Ids contains the
4367
// ids to be looked up in W.
4468
if (ids_var->IsType<LoDTensor>()) {
45-
auto* ids_t = context.Input<LoDTensor>("Ids");
46-
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
69+
auto *ids_t = context.Input<LoDTensor>("Ids");
70+
ids = const_cast<int64_t *>(ids_t->data<int64_t>());
4771
ids_numel = ids_t->numel();
4872
} else if (ids_var->IsType<SelectedRows>()) {
49-
auto* ids_t = context.Input<SelectedRows>("Ids");
50-
ids = const_cast<int64_t*>(ids_t->rows().data());
73+
auto *ids_t = context.Input<SelectedRows>("Ids");
74+
ids = const_cast<int64_t *>(ids_t->rows().data());
5175
ids_numel = ids_t->rows().size();
52-
output_t->Resize({ids_numel, table_t->dims()[1]});
76+
output_t->Resize({ids_numel, table_dim[1]});
5377
} else {
5478
PADDLE_THROW("Unsupported Variable Type of Ids");
5579
}
5680

57-
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
81+
if (table_var->IsType<LoDTensor>()) {
82+
auto *table_t = context.Input<LoDTensor>("W");
83+
int64_t row_number = table_t->dims()[0];
84+
int64_t row_width = table_t->dims()[1];
5885

59-
int N = table_t->dims()[0];
60-
int D = table_t->dims()[1];
61-
auto* table = table_t->data<T>();
62-
auto* output = output_t->mutable_data<T>(context.GetPlace());
86+
auto *table = table_t->data<T>();
87+
auto *output = output_t->mutable_data<T>(context.GetPlace());
6388

64-
if (padding_idx == -1) {
6589
for (int64_t i = 0; i < ids_numel; ++i) {
66-
PADDLE_ENFORCE_LT(ids[i], N);
67-
PADDLE_ENFORCE_GE(ids[i], 0);
68-
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
90+
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
91+
memset(output + i * row_width, 0, row_width * sizeof(T));
92+
} else {
93+
PADDLE_ENFORCE_LT(ids[i], row_number);
94+
PADDLE_ENFORCE_GE(ids[i], 0);
95+
memcpy(output + i * row_width, table + ids[i] * row_width,
96+
row_width * sizeof(T));
97+
}
6998
}
70-
} else {
99+
} else if (table_var->IsType<SelectedRows>()) {
100+
const auto &table_t = table_var->Get<SelectedRows>();
101+
int64_t row_width = table_t.value().dims()[1];
102+
const auto *table = table_t.value().data<T>();
103+
auto *output = output_t->mutable_data<T>(context.GetPlace());
104+
71105
for (int64_t i = 0; i < ids_numel; ++i) {
72-
if (ids[i] == padding_idx) {
73-
memset(output + i * D, 0, D * sizeof(T));
106+
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
107+
memset(output + i * row_width, 0, row_width * sizeof(T));
74108
} else {
75-
PADDLE_ENFORCE_LT(ids[i], N);
76109
PADDLE_ENFORCE_GE(ids[i], 0);
77-
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
110+
auto id_index = getIndex(table_t.rows(), ids[i]);
111+
memcpy(output + i * row_width, table + id_index * row_width,
112+
row_width * sizeof(T));
78113
}
79114
}
80115
}
@@ -84,17 +119,27 @@ class LookupTableKernel : public framework::OpKernel<T> {
84119
template <typename T>
85120
class LookupTableGradKernel : public framework::OpKernel<T> {
86121
public:
87-
void Compute(const framework::ExecutionContext& context) const override {
122+
void Compute(const framework::ExecutionContext &context) const override {
123+
auto *table_var = context.InputVar("W");
124+
DDim table_dim;
125+
if (table_var->IsType<LoDTensor>()) {
126+
table_dim = context.Input<LoDTensor>("W")->dims();
127+
} else if (table_var->IsType<SelectedRows>()) {
128+
auto *table_t = context.Input<SelectedRows>("W");
129+
table_dim = table_t->value().dims();
130+
} else {
131+
PADDLE_THROW("table only support LoDTensor and SelectedRows");
132+
}
133+
88134
bool is_sparse = context.Attr<bool>("is_sparse");
89135
// Since paddings are not trainable and fixed in forward, the gradient of
90136
// paddings makes no sense and we don't deal with it in backward.
91137
if (is_sparse) {
92-
auto* ids = context.Input<LoDTensor>("Ids");
93-
auto* table = context.Input<LoDTensor>("W");
94-
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
95-
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
138+
auto *ids = context.Input<LoDTensor>("Ids");
139+
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
140+
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
96141

97-
auto* ids_data = ids->data<int64_t>();
142+
auto *ids_data = ids->data<int64_t>();
98143
auto ids_dim = ids->dims();
99144

100145
framework::Vector<int64_t> new_rows;
@@ -104,31 +149,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
104149
}
105150
d_table->set_rows(new_rows);
106151

107-
auto* d_table_value = d_table->mutable_value();
108-
d_table_value->Resize({ids_dim[0], table->dims()[1]});
152+
auto *d_table_value = d_table->mutable_value();
153+
d_table_value->Resize({ids_dim[0], table_dim[1]});
109154
d_table_value->mutable_data<T>(context.GetPlace());
110155

111-
d_table->set_height(table->dims()[0]);
156+
d_table->set_height(table_dim[0]);
112157

113-
auto* d_output_data = d_output->data<T>();
114-
auto* d_table_data = d_table_value->data<T>();
158+
auto *d_output_data = d_output->data<T>();
159+
auto *d_table_data = d_table_value->data<T>();
115160

116161
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
117162
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
118163
} else {
119-
auto* ids = context.Input<LoDTensor>("Ids");
120-
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
121-
auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
122-
auto* table = context.Input<LoDTensor>("W");
164+
auto *ids = context.Input<LoDTensor>("Ids");
165+
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
166+
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
123167

124-
auto* ids_data = ids->data<int64_t>();
168+
auto *ids_data = ids->data<int64_t>();
125169
auto ids_dim = ids->dims();
126170

127-
int N = table->dims()[0];
171+
int N = table_dim[0];
128172
int D = d_output->dims()[1];
129173

130-
auto* d_output_data = d_output->data<T>();
131-
auto* d_table_data = d_table->mutable_data<T>(context.GetPlace());
174+
auto *d_output_data = d_output->data<T>();
175+
auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
132176

133177
memset(d_table_data, 0, d_table->numel() * sizeof(T));
134178

python/paddle/fluid/tests/unittests/test_lookup_table_op.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,47 @@ def test_concat_rows(self):
9696
self.check_with_place(place)
9797

9898

99+
class TestLookupTableWIsSelectedRows(OpTest):
100+
def check_with_place(self, place):
101+
scope = core.Scope()
102+
103+
# create and initialize Id Variable
104+
ids_tensor = scope.var('Ids').get_tensor()
105+
ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
106+
ids_tensor.set(ids_array, place)
107+
108+
# create and initialize W Variable
109+
rows = [0, 1, 2, 3, 4, 5, 6]
110+
row_numel = 12
111+
112+
w_selected_rows = scope.var('W').get_selected_rows()
113+
w_selected_rows.set_height(len(rows))
114+
w_selected_rows.set_rows(rows)
115+
w_array = np.ones((len(rows), row_numel)).astype("float32")
116+
for i in range(len(rows)):
117+
w_array[i] *= i
118+
ids_tensor = w_selected_rows.get_tensor()
119+
ids_tensor.set(w_array, place)
120+
121+
# create Out Variable
122+
Out_tensor = scope.var('Out').get_tensor()
123+
124+
# create and run lookup_table operator
125+
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
126+
lookup_table.run(scope, place)
127+
128+
# get result from Out
129+
result_array = np.array(Out_tensor)
130+
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
131+
for idx, row in enumerate(ids_array):
132+
assert (row[0] == result_array[idx]).all()
133+
134+
def test_w_is_selected_rows(self):
135+
places = [core.CPUPlace()]
136+
# currently only support CPU
137+
for place in places:
138+
self.check_with_place(place)
139+
140+
99141
if __name__ == "__main__":
100142
unittest.main()

0 commit comments

Comments
 (0)