Skip to content

Commit 424dd2f

Browse files
authored
Merge pull request #9597 from jacquesqiao/sgd-support-update-selected-rows
Sgd support update selected rows
2 parents 7bf82f8 + ff4208e commit 424dd2f

File tree

10 files changed

+201
-83
lines changed

10 files changed

+201
-83
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
3535
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
3636
};
3737

38+
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
39+
if (var->IsType<framework::LoDTensor>()) {
40+
return framework::ToDataType(var->Get<framework::LoDTensor>().type());
41+
} else if (var->IsType<framework::SelectedRows>()) {
42+
return framework::ToDataType(
43+
var->Get<framework::SelectedRows>().value().type());
44+
} else {
45+
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
46+
}
47+
}
48+
3849
static DDim GetDims(const Scope& scope, const std::string& name) {
3950
Variable* var = scope.FindVar(name);
4051
if (var == nullptr) {

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ inline std::string GradVarName(const std::string& var_name) {
6161
return var_name + kGradVarSuffix;
6262
}
6363

64+
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
65+
6466
class OperatorBase;
6567
class ExecutionContext;
6668

paddle/fluid/framework/selected_rows.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at
6+
57
http://www.apache.org/licenses/LICENSE-2.0
8+
69
Unless required by applicable law or agreed to in writing, software
710
distributed under the License is distributed on an "AS IS" BASIS,
811
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -13,6 +16,7 @@ limitations under the License. */
1316

1417
namespace paddle {
1518
namespace framework {
19+
1620
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
1721
const platform::DeviceContext& dev_ctx) {
1822
{ // the 1st field, uint32_t version

paddle/fluid/framework/selected_rows.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at
6+
57
http://www.apache.org/licenses/LICENSE-2.0
8+
69
Unless required by applicable law or agreed to in writing, software
710
distributed under the License is distributed on an "AS IS" BASIS,
811
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -47,6 +50,15 @@ class SelectedRows {
4750

4851
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
4952

53+
/**
54+
* get the index of id in rows
55+
*/
56+
int64_t index(int64_t id) const {
57+
auto it = std::find(rows_.begin(), rows_.end(), id);
58+
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
59+
return static_cast<int64_t>(std::distance(rows_.begin(), it));
60+
}
61+
5062
DDim GetCompleteDims() const {
5163
std::vector<int64_t> dims = vectorize(value_->dims());
5264
dims[0] = height_;

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,6 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21-
static inline framework::OpKernelType ExpectedKernelType(
22-
const framework::ExecutionContext& ctx) {
23-
auto* table_var = ctx.InputVar("W");
24-
if (table_var->IsType<LoDTensor>()) {
25-
return framework::OpKernelType(
26-
framework::ToDataType(table_var->Get<LoDTensor>().type()),
27-
ctx.device_context());
28-
} else if (table_var->IsType<SelectedRows>()) {
29-
return framework::OpKernelType(
30-
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
31-
ctx.device_context());
32-
} else {
33-
PADDLE_THROW("W should be LoDTensor or SelectedRows");
34-
}
35-
}
36-
3721
class LookupTableOp : public framework::OperatorWithKernel {
3822
public:
3923
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel {
6751
protected:
6852
framework::OpKernelType GetExpectedKernelType(
6953
const framework::ExecutionContext& ctx) const override {
70-
return ExpectedKernelType(ctx);
54+
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
55+
return framework::OpKernelType(data_type, ctx.device_context());
7156
}
7257
};
7358

@@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
138123
protected:
139124
framework::OpKernelType GetExpectedKernelType(
140125
const framework::ExecutionContext& ctx) const override {
141-
return ExpectedKernelType(ctx);
126+
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
127+
return framework::OpKernelType(data_type, ctx.device_context());
142128
}
143129
};
144130

paddle/fluid/operators/lookup_table_op.h

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor;
3030
using SelectedRows = framework::SelectedRows;
3131
using DDim = framework::DDim;
3232

33-
static constexpr int64_t kNoPadding = -1;
34-
35-
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
36-
auto it = std::find(rows.begin(), rows.end(), value);
37-
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
38-
return static_cast<size_t>(std::distance(rows.begin(), it));
39-
}
33+
constexpr int64_t kNoPadding = -1;
4034

4135
template <typename T>
4236
class LookupTableKernel : public framework::OpKernel<T> {
@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
5549
auto *table_t = context.Input<SelectedRows>("W");
5650
table_dim = table_t->value().dims();
5751
} else {
58-
PADDLE_THROW("table only support LoDTensor and SelectedRows");
52+
PADDLE_THROW(
53+
"The parameter W of a LookupTable "
54+
"must be either LoDTensor or SelectedRows");
5955
}
6056

6157
int64_t *ids;
@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
107103
memset(output + i * row_width, 0, row_width * sizeof(T));
108104
} else {
109105
PADDLE_ENFORCE_GE(ids[i], 0);
110-
auto id_index = getIndex(table_t.rows(), ids[i]);
106+
auto id_index = table_t.index(ids[i]);
111107
memcpy(output + i * row_width, table + id_index * row_width,
112108
row_width * sizeof(T));
113109
}
@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
128124
auto *table_t = context.Input<SelectedRows>("W");
129125
table_dim = table_t->value().dims();
130126
} else {
131-
PADDLE_THROW("table only support LoDTensor and SelectedRows");
127+
PADDLE_THROW(
128+
"The parameter W of a LookupTable "
129+
"must be either LoDTensor or SelectedRows");
132130
}
133131

134132
bool is_sparse = context.Attr<bool>("is_sparse");

paddle/fluid/operators/sgd_op.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,21 @@ class SGDOp : public framework::OperatorWithKernel {
4343
protected:
4444
framework::OpKernelType GetExpectedKernelType(
4545
const framework::ExecutionContext& ctx) const override {
46-
return framework::OpKernelType(
47-
framework::ToDataType(ctx.Input<framework::LoDTensor>("Param")->type()),
48-
ctx.GetPlace());
46+
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
47+
return framework::OpKernelType(data_type, ctx.device_context());
4948
}
5049
};
5150

5251
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
5352
public:
5453
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
5554
: OpProtoAndCheckerMaker(proto, op_checker) {
56-
AddInput("Param", "(Tensor) Input parameter");
55+
AddInput("Param", "(Tensor or SelectedRows) Input parameter");
5756
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
58-
AddInput("Grad", "(Tensor) Input gradient");
59-
AddOutput("ParamOut", "(Tensor) Output parameter");
57+
AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
58+
AddOutput("ParamOut",
59+
"(Tensor or SelectedRows, same with Param) "
60+
"Output parameter, should share the same memory with Param");
6061
AddComment(R"DOC(
6162
6263
SGD operator

paddle/fluid/operators/sgd_op.h

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,60 +23,97 @@ namespace operators {
2323
template <typename T>
2424
class SGDOpKernel : public framework::OpKernel<T> {
2525
public:
26-
void Compute(const framework::ExecutionContext& ctx) const override {
27-
auto* param = ctx.Input<framework::Tensor>("Param");
28-
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
29-
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
30-
31-
auto* grad_var = ctx.InputVar("Grad");
32-
// Actually, all tensors are LoDTensor except SelectedRows.
33-
if (grad_var->IsType<framework::LoDTensor>()) {
34-
param_out->mutable_data<T>(ctx.GetPlace());
35-
auto* grad = ctx.Input<framework::Tensor>("Grad");
36-
37-
auto p = framework::EigenVector<T>::Flatten(*param);
38-
auto g = framework::EigenVector<T>::Flatten(*grad);
39-
auto o = framework::EigenVector<T>::Flatten(*param_out);
40-
auto* lr = learning_rate->data<T>();
41-
42-
o = p - lr[0] * g;
43-
} else if (grad_var->IsType<framework::SelectedRows>()) {
44-
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
45-
// This manual optimization brings difficulty to track data dependency.
46-
// It's better to find a more elegant solution.
47-
PADDLE_ENFORCE_EQ(param, param_out);
48-
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
26+
void Compute(const framework::ExecutionContext &ctx) const override {
27+
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
28+
29+
const auto *param_var = ctx.InputVar("Param");
30+
const auto *grad_var = ctx.InputVar("Grad");
31+
32+
if (param_var->IsType<framework::LoDTensor>()) {
33+
const auto *param = ctx.Input<framework::Tensor>("Param");
34+
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
35+
36+
// Actually, all tensors are LoDTensor except SelectedRows.
37+
if (grad_var->IsType<framework::LoDTensor>()) {
38+
param_out->mutable_data<T>(ctx.GetPlace());
39+
const auto *grad = ctx.Input<framework::Tensor>("Grad");
40+
41+
auto p = framework::EigenVector<T>::Flatten(*param);
42+
auto g = framework::EigenVector<T>::Flatten(*grad);
43+
auto o = framework::EigenVector<T>::Flatten(*param_out);
44+
auto *lr = learning_rate->data<T>();
45+
46+
o = p - lr[0] * g;
47+
} else if (grad_var->IsType<framework::SelectedRows>()) {
48+
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
49+
// This manual optimization brings difficulty to track data dependency.
50+
// It's better to find a more elegant solution.
51+
PADDLE_ENFORCE_EQ(param, param_out);
52+
const auto *grad = ctx.Input<framework::SelectedRows>("Grad");
53+
54+
// for distributed training, a sparse var may be empty,
55+
// just skip updating.
56+
if (grad->rows().size() == 0) {
57+
return;
58+
}
59+
60+
auto grad_height = grad->height();
61+
auto out_dims = param_out->dims();
62+
PADDLE_ENFORCE_EQ(grad_height, out_dims[0]);
63+
64+
auto &grad_value = grad->value();
65+
auto &grad_rows = grad->rows();
66+
67+
size_t grad_row_numel = grad_value.numel() / grad_rows.size();
68+
PADDLE_ENFORCE_EQ(grad_row_numel, param_out->numel() / grad_height);
69+
70+
auto *grad_data = grad_value.data<T>();
71+
auto *out_data = param_out->data<T>();
72+
auto *lr = learning_rate->data<T>();
73+
for (size_t i = 0; i < grad_rows.size(); i++) {
74+
PADDLE_ENFORCE(grad_rows[i] < grad_height,
75+
"Input rows index should less than height");
76+
for (int64_t j = 0; j < grad_row_numel; j++) {
77+
out_data[grad_rows[i] * grad_row_numel + j] -=
78+
lr[0] * grad_data[i * grad_row_numel + j];
79+
}
80+
}
81+
} else {
82+
PADDLE_THROW("Unsupported Variable Type of Grad");
83+
}
84+
} else if (param_var->IsType<framework::SelectedRows>()) {
85+
PADDLE_ENFORCE(grad_var->IsType<framework::SelectedRows>(),
86+
"when param "
87+
"is SelectedRows, gradient should also be SelectedRows");
88+
const auto &param = param_var->Get<framework::SelectedRows>();
89+
auto *param_out = ctx.Output<framework::SelectedRows>("ParamOut");
90+
const auto &grad = grad_var->Get<framework::SelectedRows>();
4991

5092
// for distributed training, a sparse var may be empty,
5193
// just skip updating.
52-
if (grad->rows().size() == 0) {
94+
if (grad.rows().size() == 0) {
5395
return;
5496
}
5597

56-
auto in_height = grad->height();
57-
auto out_dims = param_out->dims();
58-
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
59-
60-
auto& in_value = grad->value();
61-
auto& in_rows = grad->rows();
98+
size_t param_row_width = param.value().numel() / param.rows().size();
99+
size_t grad_row_width = grad.value().numel() / grad.rows().size();
100+
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
101+
"param_row should have the same size with grad_row");
62102

63-
int64_t in_row_numel = in_value.numel() / in_rows.size();
64-
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
65-
66-
auto* in_data = in_value.data<T>();
67-
auto* out_data = param_out->data<T>();
68-
auto* lr = learning_rate->data<T>();
69-
for (size_t i = 0; i < in_rows.size(); i++) {
70-
PADDLE_ENFORCE(in_rows[i] < in_height,
103+
const auto *lr = learning_rate->data<T>();
104+
const auto *grad_data = grad.value().data<T>();
105+
auto *out_data = param_out->mutable_value()->data<T>();
106+
for (size_t i = 0; i < grad.rows().size(); i++) {
107+
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
71108
"Input rows index should less than height");
72-
for (int64_t j = 0; j < in_row_numel; j++) {
73-
out_data[in_rows[i] * in_row_numel + j] -=
74-
lr[0] * in_data[i * in_row_numel + j];
109+
int64_t id_index = param.index(grad.rows()[i]);
110+
for (int64_t j = 0; j < grad_row_width; j++) {
111+
out_data[id_index * grad_row_width + j] -=
112+
lr[0] * grad_data[i * grad_row_width + j];
75113
}
76114
}
77-
78115
} else {
79-
PADDLE_THROW("Unsupported Variable Type of Grad");
116+
PADDLE_THROW("Unsupported Variable Type of Parameter");
80117
}
81118
}
82119
};

python/paddle/fluid/tests/unittests/test_lookup_table_op.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,18 @@ def check_with_place(self, place):
115115
w_array = np.ones((len(rows), row_numel)).astype("float32")
116116
for i in range(len(rows)):
117117
w_array[i] *= i
118-
ids_tensor = w_selected_rows.get_tensor()
119-
ids_tensor.set(w_array, place)
118+
w_tensor = w_selected_rows.get_tensor()
119+
w_tensor.set(w_array, place)
120120

121121
# create Out Variable
122-
Out_tensor = scope.var('Out').get_tensor()
122+
out_tensor = scope.var('Out').get_tensor()
123123

124124
# create and run lookup_table operator
125125
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
126126
lookup_table.run(scope, place)
127127

128128
# get result from Out
129-
result_array = np.array(Out_tensor)
129+
result_array = np.array(out_tensor)
130130
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
131131
for idx, row in enumerate(ids_array):
132132
assert (row[0] == result_array[idx]).all()

0 commit comments

Comments
 (0)