Commit 8468037

Author: Yancey

Fix sparse update memory error for distributed training (#8837)

1 parent 124b750 commit 8468037
7 files changed: +72 −27 lines changed

paddle/fluid/operators/send_op.cc

Lines changed: 4 additions & 4 deletions
@@ -24,15 +24,15 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
-static bool IsVariableInitialized(const framework::Scope& scope,
-                                  const std::string& varname) {
+static bool NeedSend(const framework::Scope& scope,
+                     const std::string& varname) {
   auto* var = scope.FindVar(varname);
   PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                           varname);
   if (var->IsType<framework::LoDTensor>()) {
     return var->Get<framework::LoDTensor>().IsInitialized();
   } else if (var->IsType<framework::SelectedRows>()) {
-    return var->Get<framework::SelectedRows>().value().IsInitialized();
+    return var->Get<framework::SelectedRows>().rows().size() > 0UL;
   } else {
     PADDLE_THROW(
         "Variable type in send side should be in "
@@ -67,7 +67,7 @@ class SendOp : public framework::OperatorBase {
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
 
     for (size_t i = 0; i < ins.size(); i++) {
-      if (IsVariableInitialized(scope, ins[i])) {
+      if (NeedSend(scope, ins[i])) {
         VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
         rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
       } else {
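The renamed check treats a SelectedRows variable with zero selected rows as "nothing to send", instead of asking whether its value tensor is initialized (which an empty sparse gradient may never allocate). A minimal standalone sketch of that predicate, using a hypothetical SparseVar stand-in rather than Paddle's framework::SelectedRows:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::SelectedRows.
struct SparseVar {
  std::vector<int64_t> rows;  // indices of the selected rows
  std::vector<float> value;   // row data, rows.size() * row_numel entries
};

static bool NeedSend(const SparseVar& var) {
  // An empty row set means no update, even if value storage exists.
  return !var.rows.empty();
}

int main() {
  SparseVar empty;                       // e.g. a shard that received no rows
  SparseVar filled{{3, 7}, {1.f, 2.f}};  // two selected rows, row_numel = 1
  std::cout << NeedSend(empty) << " " << NeedSend(filled) << "\n";  // 0 1
}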

paddle/fluid/operators/sgd_op.cc

Lines changed: 8 additions & 0 deletions
@@ -39,6 +39,14 @@ class SGDOp : public framework::OperatorWithKernel {
     // and run time.
     ctx->SetOutputDim("ParamOut", param_dim);
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("Param")->type()),
+        ctx.GetPlace());
+  }
 };
 
 class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
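Overriding GetExpectedKernelType lets the SGD op pick its kernel from the always-initialized "Param" tensor instead of inferring it from inputs that may include an empty sparse gradient. A rough sketch of the idea with hypothetical stand-in types (not Paddle's op registry):

#include <iostream>
#include <optional>

enum class DType { FP32, FP64 };

// Hypothetical tensor: the dtype is only known once memory is allocated.
struct Tensor {
  std::optional<DType> dtype;  // std::nullopt => uninitialized
};

// Pick the kernel's data type from param, which is always allocated;
// reading grad.dtype would fail whenever the sparse gradient is empty.
static DType ExpectedKernelType(const Tensor& param, const Tensor& grad) {
  (void)grad;  // intentionally unused: that is the point of the fix
  return param.dtype.value();
}

int main() {
  Tensor param{DType::FP32};
  Tensor empty_grad{};  // an empty sparse grad never allocated its value
  std::cout << (ExpectedKernelType(param, empty_grad) == DType::FP32) << "\n";
}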

paddle/fluid/operators/sgd_op.h

Lines changed: 9 additions & 1 deletion
@@ -47,6 +47,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(param, param_out);
       auto* grad = ctx.Input<framework::SelectedRows>("Grad");
 
+      // for distributed training, a sparse var may be empty,
+      // just skip updating.
+      if (grad->rows().size() == 0) {
+        return;
+      }
+
       auto in_height = grad->height();
       auto out_dims = param_out->dims();
       PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
@@ -60,13 +66,15 @@ class SGDOpKernel : public framework::OpKernel<T> {
       auto* in_data = in_value.data<T>();
       auto* out_data = param_out->data<T>();
       auto* lr = learning_rate->data<T>();
-
       for (size_t i = 0; i < in_rows.size(); i++) {
+        PADDLE_ENFORCE(in_rows[i] < in_height,
+                       "Input rows index should less than height");
         for (int64_t j = 0; j < in_row_numel; j++) {
           out_data[in_rows[i] * in_row_numel + j] -=
               lr[0] * in_data[i * in_row_numel + j];
         }
       }
+
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");
     }
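Two guards are added here: an early return when the sparse gradient has no rows, and a bounds check so a bad row index can never write past the parameter's height, which is exactly the kind of memory error this commit fixes. A self-contained sketch of the guarded update, assuming flat std::vector storage in place of Paddle's tensors:

#include <cassert>
#include <cstdint>
#include <vector>

static void SparseSgdUpdate(std::vector<float>& param,         // height x row_numel
                            int64_t height, int64_t row_numel,
                            const std::vector<int64_t>& rows,  // selected rows
                            const std::vector<float>& grad,    // rows.size() x row_numel
                            float lr) {
  if (rows.empty()) return;  // empty sparse grad: nothing to update
  for (size_t i = 0; i < rows.size(); ++i) {
    // guard against out-of-range writes into param
    assert(rows[i] < height && "Input rows index should less than height");
    for (int64_t j = 0; j < row_numel; ++j) {
      param[rows[i] * row_numel + j] -= lr * grad[i * row_numel + j];
    }
  }
}

int main() {
  std::vector<float> param(4 * 2, 1.f);  // height = 4, row_numel = 2
  SparseSgdUpdate(param, 4, 2, {1, 3}, {.5f, .5f, .5f, .5f}, 0.1f);
  SparseSgdUpdate(param, 4, 2, {}, {}, 0.1f);  // empty grad: no-op
}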

paddle/fluid/operators/split_selected_rows_op.h

Lines changed: 31 additions & 14 deletions
@@ -21,15 +21,24 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-static int FindOutIdx(int row, const std::vector<int>& height_sections) {
-  int offset = 0;
-  for (size_t i = 0; i < height_sections.size(); ++i) {
-    if (row >= offset && row < (offset + height_sections[i])) {
-      return i;
+static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
+  for (size_t i = 1; i < abs_sections.size(); ++i) {
+    if (row < abs_sections[i]) {
+      return i - 1;
     }
-    offset += height_sections[i];
   }
-  return -1;
+  return abs_sections.size() - 1;
+}
+
+static std::vector<int> ToAbsoluteSection(
+    const std::vector<int>& height_sections) {
+  std::vector<int> abs_sections;
+  abs_sections.resize(height_sections.size());
+  abs_sections[0] = 0;
+  for (size_t i = 1; i < height_sections.size(); ++i) {
+    abs_sections[i] = height_sections[i - 1] + abs_sections[i - 1];
+  }
+  return abs_sections;
 }
 
 template <typename DeviceContext, typename T>
@@ -40,16 +49,23 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
     auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
     auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
 
+    auto abs_sections = ToAbsoluteSection(height_sections);
+
     auto x_rows = x->rows();
     std::vector<std::vector<int>> outs_rows_idx;
+    std::vector<std::vector<int>> outs_dense_idx;
+
     outs_rows_idx.resize(outs.size());
+    outs_dense_idx.resize(outs.size());
 
     auto row_numel = x->value().numel() / x->value().dims()[0];
     auto src = x->value().data<T>();
 
+    // split rows index into output sparse vars
     for (size_t i = 0; i < x_rows.size(); ++i) {
-      int out_idx = FindOutIdx(x_rows[i], height_sections);
-      outs_rows_idx[out_idx].push_back(i);
+      int out_idx = FindOutIdx(x_rows[i], abs_sections);
+      outs_rows_idx[out_idx].push_back(x_rows[i]);
+      outs_dense_idx[out_idx].push_back(i);
     }
     auto place = ctx.GetPlace();
 
@@ -61,19 +77,20 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
         dims[0] = rows_idx.size();
         outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
         for (auto idx : rows_idx) {
-          outs[i]->mutable_rows()->push_back(x_rows[idx]);
+          outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
         }
         auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
         for (size_t j = 0; j < rows_idx.size(); j++) {
           if (platform::is_cpu_place(place)) {
-            memory::Copy(platform::CPUPlace(), dst + j * row_numel,
-                         platform::CPUPlace(), src + rows_idx[j] * row_numel,
-                         sizeof(T) * row_numel);
+            memory::Copy(
+                platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
+                src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
           } else {
 #ifdef PADDLE_WITH_CUDA
             auto stream = ctx.cuda_device_context().stream();
             memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
-                         platform::CUDAPlace(), src + rows_idx[j] * row_numel,
+                         platform::CUDAPlace(),
+                         src + outs_dense_idx[i][j] * row_numel,
                          sizeof(T) * row_numel, stream);
 #else
             PADDLE_THROW("Paddle is not compiled with GPU");
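This is the core of the fix: height_sections (per-shard heights) are first converted to absolute offsets, each global row is routed by those offsets, and the row index stored in every output becomes relative to its section, while outs_dense_idx separately remembers each row's position in the source value tensor so the copy reads the right data. A standalone sketch with a worked example (the {3, 3, 4} split below is illustrative, not taken from the source):

#include <cstddef>
#include <iostream>
#include <vector>

// Same logic as the new ToAbsoluteSection: prefix-sum the section heights.
static std::vector<int> ToAbsoluteSection(const std::vector<int>& heights) {
  std::vector<int> abs_sections(heights.size(), 0);
  for (size_t i = 1; i < heights.size(); ++i)
    abs_sections[i] = abs_sections[i - 1] + heights[i - 1];
  return abs_sections;
}

// Same logic as the new FindOutIdx: the last section catches all the rest,
// so a valid row can never be routed to index -1 as before.
static int FindOutIdx(int row, const std::vector<int>& abs_sections) {
  for (size_t i = 1; i < abs_sections.size(); ++i)
    if (row < abs_sections[i]) return static_cast<int>(i) - 1;
  return static_cast<int>(abs_sections.size()) - 1;
}

int main() {
  // e.g. a height-10 table sharded as 3 + 3 + 4 across three servers
  auto abs_sections = ToAbsoluteSection({3, 3, 4});  // -> {0, 3, 6}
  for (int row : {0, 4, 9}) {
    int idx = FindOutIdx(row, abs_sections);
    std::cout << "row " << row << " -> section " << idx << ", local row "
              << row - abs_sections[idx] << "\n";
  }
  // prints: row 0 -> section 0, local row 0
  //         row 4 -> section 1, local row 1
  //         row 9 -> section 2, local row 3
}

Storing the local (section-relative) row index is what keeps the parameter server from indexing past the height of its own shard.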

paddle/fluid/operators/sum_op.cc

Lines changed: 10 additions & 4 deletions
@@ -76,10 +76,16 @@ class SumOp : public framework::OperatorWithKernel {
           static_cast<framework::proto::VarType::Type>(dtype),
           ctx.device_context());
     } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
-      return framework::OpKernelType(
-          framework::ToDataType(
-              x_vars[0]->Get<framework::SelectedRows>().value().type()),
-          ctx.device_context());
+      for (auto& var : x_vars) {
+        auto& value = var->Get<framework::SelectedRows>().value();
+        if (value.IsInitialized()) {
+          return framework::OpKernelType(framework::ToDataType(value.type()),
+                                         ctx.device_context());
+        }
+      }
+      // if input sparse vars are not initialized, use an default kernel type.
+      return framework::OpKernelType(framework::proto::VarType::FP32,
+                                     ctx.device_context());
     } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
       for (auto& x_var : x_vars) {
         auto& array = x_var->Get<framework::LoDTensorArray>();
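Instead of unconditionally reading the dtype of the first input's value tensor, which may be uninitialized, SumOp now scans for the first initialized sparse input and falls back to FP32 when all of them are empty. A minimal sketch of that selection rule, with hypothetical stand-in types rather than Paddle's:

#include <iostream>
#include <optional>
#include <vector>

enum class DType { FP32, FP64 };

// Hypothetical stand-in for one SelectedRows input.
struct SparseInput {
  std::optional<DType> dtype;  // std::nullopt => value tensor not initialized
};

static DType SumKernelType(const std::vector<SparseInput>& inputs) {
  for (const auto& in : inputs)
    if (in.dtype) return *in.dtype;  // first initialized input wins
  return DType::FP32;                // all empty: default kernel type
}

int main() {
  std::vector<SparseInput> all_empty{{}, {}};
  std::vector<SparseInput> mixed{{}, {DType::FP64}};
  std::cout << (SumKernelType(all_empty) == DType::FP32) << " "
            << (SumKernelType(mixed) == DType::FP64) << "\n";  // 1 1
}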

paddle/fluid/operators/sum_op.h

Lines changed: 7 additions & 1 deletion
@@ -109,14 +109,20 @@ class SumKernel : public framework::OpKernel<T> {
       in_dim[0] = static_cast<int64_t>(first_dim);
 
       out_value->Resize(framework::make_ddim(in_dim));
+
+      // if all the input sparse vars are empty, no need to
+      // merge these vars.
+      if (first_dim == 0UL) {
+        return;
+      }
       out_value->mutable_data<T>(context.GetPlace());
 
       math::SelectedRowsAddTo<DeviceContext, T> functor;
 
       int64_t offset = 0;
       for (int i = 0; i < N; i++) {
         auto &sel_row = get_selected_row(i);
-        if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) {
+        if (sel_row.rows().size() == 0) {
           continue;
         }
         PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
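The kernel now totals the input row counts first and returns before allocating output memory when the merged result would have zero rows; the per-input skip also no longer needs to consult value().IsInitialized(). A loose sketch of the control flow, assuming a stand-in SparseIn type in place of SelectedRows:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for one SelectedRows input.
struct SparseIn {
  std::vector<int64_t> rows;
};

// Returns the number of rows in the merged output.
static size_t SumSelectedRows(const std::vector<SparseIn>& ins) {
  size_t first_dim = 0;
  for (const auto& in : ins) first_dim += in.rows.size();
  // if all the input sparse vars are empty, no need to merge them
  if (first_dim == 0) return 0;  // skip "allocation" and the add loop
  size_t merged = 0;
  for (const auto& in : ins) {
    if (in.rows.empty()) continue;  // empty inputs are skipped individually
    merged += in.rows.size();       // stands in for SelectedRowsAddTo
  }
  return merged;
}

int main() {
  std::cout << SumSelectedRows({{}, {}}) << "\n";        // 0: early out
  std::cout << SumSelectedRows({{{1, 2}}, {}}) << "\n";  // 2
}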

python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py

Lines changed: 3 additions & 3 deletions
@@ -60,8 +60,8 @@ def check_with_place(self, place):
 
         # expected output selected rows
         expected_out0_rows = [0, 4]
-        expected_out1_rows = [5, 7]
-        expected_out4_rows = [20]
+        expected_out1_rows = [0, 2]
+        expected_out4_rows = [0]
 
         op = Operator(
             "split_selected_rows",
@@ -101,7 +101,7 @@ def check_grad_with_place(self, place):
         out0_grad_tensor.set(np_array, place)
 
         out1_grad = scope.var("out1@GRAD").get_selected_rows()
-        rows1 = [7, 5]
+        rows1 = [2, 0]
         out1_grad.set_rows(rows1)
         out1_grad.set_height(height)
         out1_grad_tensor = out1_grad.get_tensor()
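The expected values change because split_selected_rows now emits section-relative row indices. Assuming section 1 begins at absolute offset 5 and section 4 at offset 20 (consistent with the deltas in this diff), global rows [5, 7] map to [5 - 5, 7 - 5] = [0, 2] and global row [20] maps to [20 - 20] = [0]; the gradient-side rows in check_grad_with_place are rewritten the same way.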
