@@ -14,6 +14,9 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include < string>
18
+ #include < vector>
19
+
17
20
#include " paddle/fluid/framework/eigen.h"
18
21
#include " paddle/fluid/framework/lod_tensor.h"
19
22
#include " paddle/fluid/framework/op_registry.h"
@@ -25,56 +28,88 @@ namespace operators {
25
28
using Tensor = framework::Tensor;
26
29
using LoDTensor = framework::LoDTensor;
27
30
using SelectedRows = framework::SelectedRows;
31
+ using DDim = framework::DDim;
32
+
33
// Sentinel for the "padding_idx" attribute: when padding_idx equals this
// value no id is treated as a padding id, so no output row is zero-filled.
static constexpr int64_t kNoPadding = -1;
34
+
35
+ inline size_t getIndex (const std::vector<int64_t > &rows, int64_t value) {
36
+ auto it = std::find (rows.begin (), rows.end (), value);
37
+ PADDLE_ENFORCE (it != rows.end (), " id should be in rows" );
38
+ return static_cast <size_t >(std::distance (rows.begin (), it));
39
+ }
28
40
29
41
template <typename T>
30
42
class LookupTableKernel : public framework ::OpKernel<T> {
31
43
public:
32
- void Compute (const framework::ExecutionContext& context) const override {
33
- auto * table_t = context.Input <LoDTensor>(" W" );
34
- auto * ids_var = context.InputVar (" Ids" );
35
- Tensor* output_t = context.Output <Tensor>(" Out" );
44
+ void Compute (const framework::ExecutionContext &context) const override {
45
+ auto *table_var = context.InputVar (" W" );
46
+ auto *ids_var = context.InputVar (" Ids" );
47
+ Tensor *output_t = context.Output <Tensor>(" Out" );
48
+ int64_t padding_idx = context.Attr <int64_t >(" padding_idx" );
49
+
50
+ DDim table_dim;
36
51
37
- int64_t * ids;
52
+ if (table_var->IsType <LoDTensor>()) {
53
+ table_dim = context.Input <LoDTensor>(" W" )->dims ();
54
+ } else if (table_var->IsType <SelectedRows>()) {
55
+ auto *table_t = context.Input <SelectedRows>(" W" );
56
+ table_dim = table_t ->value ().dims ();
57
+ } else {
58
+ PADDLE_THROW (" table only support LoDTensor and SelectedRows" );
59
+ }
60
+
61
+ int64_t *ids;
38
62
int64_t ids_numel;
39
63
40
64
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
41
65
// is LoDTensor, this tensor contains the ids to be looked up in W;
42
66
// when Ids's type is SelectedRows, the rows of Ids contains the
43
67
// ids to be looked up in W.
44
68
if (ids_var->IsType <LoDTensor>()) {
45
- auto * ids_t = context.Input <LoDTensor>(" Ids" );
46
- ids = const_cast <int64_t *>(ids_t ->data <int64_t >());
69
+ auto * ids_t = context.Input <LoDTensor>(" Ids" );
70
+ ids = const_cast <int64_t *>(ids_t ->data <int64_t >());
47
71
ids_numel = ids_t ->numel ();
48
72
} else if (ids_var->IsType <SelectedRows>()) {
49
- auto * ids_t = context.Input <SelectedRows>(" Ids" );
50
- ids = const_cast <int64_t *>(ids_t ->rows ().data ());
73
+ auto * ids_t = context.Input <SelectedRows>(" Ids" );
74
+ ids = const_cast <int64_t *>(ids_t ->rows ().data ());
51
75
ids_numel = ids_t ->rows ().size ();
52
- output_t ->Resize ({ids_numel, table_t -> dims () [1 ]});
76
+ output_t ->Resize ({ids_numel, table_dim [1 ]});
53
77
} else {
54
78
PADDLE_THROW (" Unsupported Variable Type of Ids" );
55
79
}
56
80
57
- int64_t padding_idx = context.Attr <int64_t >(" padding_idx" );
81
+ if (table_var->IsType <LoDTensor>()) {
82
+ auto *table_t = context.Input <LoDTensor>(" W" );
83
+ int64_t row_number = table_t ->dims ()[0 ];
84
+ int64_t row_width = table_t ->dims ()[1 ];
58
85
59
- int N = table_t ->dims ()[0 ];
60
- int D = table_t ->dims ()[1 ];
61
- auto * table = table_t ->data <T>();
62
- auto * output = output_t ->mutable_data <T>(context.GetPlace ());
86
+ auto *table = table_t ->data <T>();
87
+ auto *output = output_t ->mutable_data <T>(context.GetPlace ());
63
88
64
- if (padding_idx == -1 ) {
65
89
for (int64_t i = 0 ; i < ids_numel; ++i) {
66
- PADDLE_ENFORCE_LT (ids[i], N);
67
- PADDLE_ENFORCE_GE (ids[i], 0 );
68
- memcpy (output + i * D, table + ids[i] * D, D * sizeof (T));
90
+ if (padding_idx != kNoPadding && ids[i] == padding_idx) {
91
+ memset (output + i * row_width, 0 , row_width * sizeof (T));
92
+ } else {
93
+ PADDLE_ENFORCE_LT (ids[i], row_number);
94
+ PADDLE_ENFORCE_GE (ids[i], 0 );
95
+ memcpy (output + i * row_width, table + ids[i] * row_width,
96
+ row_width * sizeof (T));
97
+ }
69
98
}
70
- } else {
99
+ } else if (table_var->IsType <SelectedRows>()) {
100
+ const auto &table_t = table_var->Get <SelectedRows>();
101
+ int64_t row_width = table_t .value ().dims ()[1 ];
102
+ const auto *table = table_t .value ().data <T>();
103
+ auto *output = output_t ->mutable_data <T>(context.GetPlace ());
104
+
71
105
for (int64_t i = 0 ; i < ids_numel; ++i) {
72
- if (ids[i] == padding_idx) {
73
- memset (output + i * D , 0 , D * sizeof (T));
106
+ if (padding_idx != kNoPadding && ids[i] == padding_idx) {
107
+ memset (output + i * row_width , 0 , row_width * sizeof (T));
74
108
} else {
75
- PADDLE_ENFORCE_LT (ids[i], N);
76
109
PADDLE_ENFORCE_GE (ids[i], 0 );
77
- memcpy (output + i * D, table + ids[i] * D, D * sizeof (T));
110
+ auto id_index = getIndex (table_t .rows (), ids[i]);
111
+ memcpy (output + i * row_width, table + id_index * row_width,
112
+ row_width * sizeof (T));
78
113
}
79
114
}
80
115
}
@@ -84,17 +119,27 @@ class LookupTableKernel : public framework::OpKernel<T> {
84
119
template <typename T>
85
120
class LookupTableGradKernel : public framework ::OpKernel<T> {
86
121
public:
87
- void Compute (const framework::ExecutionContext& context) const override {
122
+ void Compute (const framework::ExecutionContext &context) const override {
123
+ auto *table_var = context.InputVar (" W" );
124
+ DDim table_dim;
125
+ if (table_var->IsType <LoDTensor>()) {
126
+ table_dim = context.Input <LoDTensor>(" W" )->dims ();
127
+ } else if (table_var->IsType <SelectedRows>()) {
128
+ auto *table_t = context.Input <SelectedRows>(" W" );
129
+ table_dim = table_t ->value ().dims ();
130
+ } else {
131
+ PADDLE_THROW (" table only support LoDTensor and SelectedRows" );
132
+ }
133
+
88
134
bool is_sparse = context.Attr <bool >(" is_sparse" );
89
135
// Since paddings are not trainable and fixed in forward, the gradient of
90
136
// paddings makes no sense and we don't deal with it in backward.
91
137
if (is_sparse) {
92
- auto * ids = context.Input <LoDTensor>(" Ids" );
93
- auto * table = context.Input <LoDTensor>(" W" );
94
- auto * d_output = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
95
- auto * d_table = context.Output <SelectedRows>(framework::GradVarName (" W" ));
138
+ auto *ids = context.Input <LoDTensor>(" Ids" );
139
+ auto *d_output = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
140
+ auto *d_table = context.Output <SelectedRows>(framework::GradVarName (" W" ));
96
141
97
- auto * ids_data = ids->data <int64_t >();
142
+ auto * ids_data = ids->data <int64_t >();
98
143
auto ids_dim = ids->dims ();
99
144
100
145
framework::Vector<int64_t > new_rows;
@@ -104,31 +149,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
104
149
}
105
150
d_table->set_rows (new_rows);
106
151
107
- auto * d_table_value = d_table->mutable_value ();
108
- d_table_value->Resize ({ids_dim[0 ], table-> dims () [1 ]});
152
+ auto * d_table_value = d_table->mutable_value ();
153
+ d_table_value->Resize ({ids_dim[0 ], table_dim [1 ]});
109
154
d_table_value->mutable_data <T>(context.GetPlace ());
110
155
111
- d_table->set_height (table-> dims () [0 ]);
156
+ d_table->set_height (table_dim [0 ]);
112
157
113
- auto * d_output_data = d_output->data <T>();
114
- auto * d_table_data = d_table_value->data <T>();
158
+ auto * d_output_data = d_output->data <T>();
159
+ auto * d_table_data = d_table_value->data <T>();
115
160
116
161
PADDLE_ENFORCE_EQ (d_table_value->dims (), d_output->dims ());
117
162
memcpy (d_table_data, d_output_data, sizeof (T) * d_output->numel ());
118
163
} else {
119
- auto * ids = context.Input <LoDTensor>(" Ids" );
120
- auto * d_output = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
121
- auto * d_table = context.Output <LoDTensor>(framework::GradVarName (" W" ));
122
- auto * table = context.Input <LoDTensor>(" W" );
164
+ auto *ids = context.Input <LoDTensor>(" Ids" );
165
+ auto *d_output = context.Input <LoDTensor>(framework::GradVarName (" Out" ));
166
+ auto *d_table = context.Output <LoDTensor>(framework::GradVarName (" W" ));
123
167
124
- auto * ids_data = ids->data <int64_t >();
168
+ auto * ids_data = ids->data <int64_t >();
125
169
auto ids_dim = ids->dims ();
126
170
127
- int N = table-> dims () [0 ];
171
+ int N = table_dim [0 ];
128
172
int D = d_output->dims ()[1 ];
129
173
130
- auto * d_output_data = d_output->data <T>();
131
- auto * d_table_data = d_table->mutable_data <T>(context.GetPlace ());
174
+ auto * d_output_data = d_output->data <T>();
175
+ auto * d_table_data = d_table->mutable_data <T>(context.GetPlace ());
132
176
133
177
memset (d_table_data, 0 , d_table->numel () * sizeof (T));
134
178
0 commit comments