Skip to content

Commit 31e8d80

Browse files
committed
optimize code
1 parent af1d3f5 commit 31e8d80

File tree

4 files changed

+18
-22
lines changed

4 files changed

+18
-22
lines changed

paddle/fluid/framework/selected_rows.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,6 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace framework {
1919

20-
size_t GetIndex(const std::vector<int64_t>& rows, int64_t value) {
21-
auto it = std::find(rows.begin(), rows.end(), value);
22-
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
23-
return static_cast<size_t>(std::distance(rows.begin(), it));
24-
}
25-
2620
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
2721
const platform::DeviceContext& dev_ctx) {
2822
{ // the 1st field, uint32_t version

paddle/fluid/framework/selected_rows.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ class SelectedRows {
5050

5151
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
5252

53+
/**
54+
* get the index of id in rows
55+
*/
56+
int64_t index(int64_t id) const {
57+
auto it = std::find(rows_.begin(), rows_.end(), id);
58+
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
59+
return static_cast<int64_t>(std::distance(rows_.begin(), it));
60+
}
61+
5362
DDim GetCompleteDims() const {
5463
std::vector<int64_t> dims = vectorize(value_->dims());
5564
dims[0] = height_;
@@ -65,11 +74,6 @@ class SelectedRows {
6574
int64_t height_;
6675
};
6776

68-
/**
69-
* Find the index of value in rows.
70-
*/
71-
size_t GetIndex(const std::vector<int64_t>& rows, int64_t value);
72-
7377
/*
7478
* Serialize/Desiralize SelectedRows to std::ostream
7579
* You can pass ofstream or ostringstream to serilize to file

paddle/fluid/operators/lookup_table_op.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor;
3030
using SelectedRows = framework::SelectedRows;
3131
using DDim = framework::DDim;
3232

33-
static constexpr int64_t kNoPadding = -1;
34-
35-
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
36-
auto it = std::find(rows.begin(), rows.end(), value);
37-
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
38-
return static_cast<size_t>(std::distance(rows.begin(), it));
39-
}
33+
constexpr int64_t kNoPadding = -1;
4034

4135
template <typename T>
4236
class LookupTableKernel : public framework::OpKernel<T> {
@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
5549
auto *table_t = context.Input<SelectedRows>("W");
5650
table_dim = table_t->value().dims();
5751
} else {
58-
PADDLE_THROW("table only support LoDTensor and SelectedRows");
52+
PADDLE_THROW(
53+
"The parameter W of a LookupTable "
54+
"must be either LoDTensor or SelectedRows");
5955
}
6056

6157
int64_t *ids;
@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
107103
memset(output + i * row_width, 0, row_width * sizeof(T));
108104
} else {
109105
PADDLE_ENFORCE_GE(ids[i], 0);
110-
auto id_index = getIndex(table_t.rows(), ids[i]);
106+
auto id_index = table_t.index(ids[i]);
111107
memcpy(output + i * row_width, table + id_index * row_width,
112108
row_width * sizeof(T));
113109
}
@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
128124
auto *table_t = context.Input<SelectedRows>("W");
129125
table_dim = table_t->value().dims();
130126
} else {
131-
PADDLE_THROW("table only support LoDTensor and SelectedRows");
127+
PADDLE_THROW(
128+
"The parameter W of a LookupTable "
129+
"must be either LoDTensor or SelectedRows");
132130
}
133131

134132
bool is_sparse = context.Attr<bool>("is_sparse");

paddle/fluid/operators/sgd_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
106106
for (size_t i = 0; i < grad.rows().size(); i++) {
107107
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
108108
"Input rows index should less than height");
109-
size_t id_index = framework::GetIndex(param.rows(), grad.rows()[i]);
109+
int64_t id_index = param.index(grad.rows()[i]);
110110
for (int64_t j = 0; j < grad_row_width; j++) {
111111
out_data[id_index * grad_row_width + j] -=
112112
lr[0] * grad_data[i * grad_row_width + j];

0 commit comments

Comments
 (0)