
Commit caf9a09

Authored by Yancey
Merge selected rows with dynamic variable count (#8023)
* dynamic send/recv of selected rows
* updated per review comments
* fixed per review comments
Parent: 4f4abfa

File tree: 6 files changed, +47 −25 lines

6 files changed

+47
-25
lines changed

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 16 additions & 0 deletions

@@ -101,6 +101,9 @@ class ListenAndServOp : public framework::OperatorBase {
 
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     bool exit_flag = false;
+    // Record the received sparse variables, so that
+    // we can reset them after executing the optimize program.
+    std::vector<framework::Variable *> sparse_vars;
     while (!exit_flag) {
       // Get from multiple trainers, we don't care about the order in which
       // the gradients arrive, just add suffix 0~n and merge the gradient.
@@ -143,6 +146,9 @@ class ListenAndServOp : public framework::OperatorBase {
           PADDLE_THROW("Can not find server side var");
         }
         detail::DeserializeFromMessage(v.second, dev_ctx, var);
+        if (var->IsType<framework::SelectedRows>()) {
+          sparse_vars.push_back(var);
+        }
       }
     }
     VLOG(3) << "recv " << recv_var_cnt << " parameters for one barrier.";
@@ -156,9 +162,19 @@ class ListenAndServOp : public framework::OperatorBase {
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
       }
+
+      // Reset the received sparse variables, so the sum operator will not
+      // sum input sparse variables whose rows are empty at the next
+      // mini-batch.
+      // TODO(Yancey1989): move the reset action into an operator; we shouldn't
+      // have any hidden logic in the operator.
+      for (auto &var : sparse_vars) {
+        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+      }
       rpc_service_->SetCond(1);
       rpc_service_->WaitClientGet(update_param_cnt);
       grads_counter_.clear();
+      sparse_vars.clear();
     }  // while(true)
   }
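The change records every received SelectedRows variable and clears its rows once the optimize program has run, so a stale sparse gradient cannot be summed again in the next mini-batch. Below is a minimal standalone sketch of that record-and-reset pattern; the SelectedRows struct here is a toy stand-in for paddle::framework::SelectedRows, not the real class.

#include <iostream>
#include <vector>

// Toy stand-in for paddle::framework::SelectedRows (simplified assumption).
struct SelectedRows {
  std::vector<long> rows;    // indices of the rows actually present
  std::vector<float> value;  // flattened row data
  std::vector<long>* mutable_rows() { return &rows; }
};

int main() {
  SelectedRows grad;
  std::vector<SelectedRows*> sparse_vars;  // vars received this mini-batch

  // Receive phase: a sparse gradient arrives and is recorded.
  grad.rows = {0, 3, 7};
  grad.value = {0.1f, 0.2f, 0.3f};
  sparse_vars.push_back(&grad);

  // ... the optimize program would run here ...

  // Reset phase: clear the rows so a stale gradient is not summed again
  // in the next mini-batch.
  for (auto* var : sparse_vars) {
    var->mutable_rows()->clear();
  }
  sparse_vars.clear();

  std::cout << "rows after reset: " << grad.rows.size() << "\n";  // prints 0
  return 0;
}

Clearing only the rows (not the value buffer) is cheap, and it is exactly what makes the empty-rows check added to sum_op.h below sufficient to skip such a variable.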

paddle/fluid/operators/send_op.cc

Lines changed: 22 additions & 2 deletions

@@ -24,6 +24,22 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
+static bool IsVariableInitialized(const framework::Scope& scope,
+                                  const std::string& varname) {
+  auto* var = scope.FindVar(varname);
+  PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
+                          varname);
+  if (var->IsType<framework::LoDTensor>()) {
+    return var->Get<framework::LoDTensor>().IsInitialized();
+  } else if (var->IsType<framework::SelectedRows>()) {
+    return var->Get<framework::SelectedRows>().value().IsInitialized();
+  } else {
+    PADDLE_THROW(
+        "Variable type in send side should be in "
+        "[LoDTensor, SelectedRows]");
+  }
+  return false;
+}
 
 class SendOp : public framework::OperatorBase {
  public:
@@ -51,8 +67,12 @@ class SendOp : public framework::OperatorBase {
     detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
 
     for (size_t i = 0; i < ins.size(); i++) {
-      VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
-      rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
+      if (IsVariableInitialized(scope, ins[i])) {
+        VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
+        rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
+      } else {
+        VLOG(3) << "don't send uninitialized variable: " << ins[i];
+      }
     }
     PADDLE_ENFORCE(rpc_client->Wait());
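With this guard a trainer only ships variables that actually hold data; a variable whose tensor was never allocated in the current step is logged and skipped. A minimal sketch of the same send-side check, using a toy Tensor type and hypothetical variable names rather than the framework classes:

#include <iostream>
#include <string>

// Toy stand-in for a framework tensor (simplified assumption).
struct Tensor {
  bool initialized = false;
  bool IsInitialized() const { return initialized; }
};

// Send only when the tensor holds data; log and skip otherwise.
void MaybeSend(const std::string& name, const Tensor& t) {
  if (t.IsInitialized()) {
    std::cout << "sending " << name << "\n";
  } else {
    std::cout << "don't send uninitialized variable: " << name << "\n";
  }
}

int main() {
  Tensor ready;
  ready.initialized = true;
  Tensor empty;  // e.g. a sparse gradient that produced no rows this batch
  MaybeSend("w@GRAD.block0", ready);  // hypothetical variable names
  MaybeSend("w@GRAD.block1", empty);
  return 0;
}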

paddle/fluid/operators/split_selected_rows_op.cc

Lines changed: 1 addition & 22 deletions

@@ -22,7 +22,7 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
   SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input SelectedRows.");
-    AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable();
+    AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable();
     AddAttr<std::vector<int>>("height_sections",
                               "Height for each output SelectedRows.")
         .SetDefault(std::vector<int>({}));
@@ -56,27 +56,6 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must have input X.");
     PADDLE_ENFORCE(ctx->HasOutputs("Out"),
                    "SplitSelectedRowsOp must have output Out.");
-
-    std::vector<int> height_sections =
-        ctx->Attrs().Get<std::vector<int>>("height_sections");
-    int64_t n = ctx->Outputs("Out").size();
-
-    std::vector<framework::DDim> outs_dims;
-    outs_dims.reserve(n);
-
-    // make output dims
-    for (int64_t i = 0; i < n; ++i) {
-      auto dims = ctx->GetInputDim("X");
-      if (height_sections.size()) {
-        PADDLE_ENFORCE_EQ(
-            height_sections.size(), static_cast<size_t>(n),
-            "The size of height section should be the same with height"
-            " section size.");
-        dims[0] = height_sections[i];
-      }
-      outs_dims.push_back(dims);
-    }
-    ctx->SetOutputsDim("Out", outs_dims);
   }
 };

Note: with a dynamic variable count, the number of rows each output receives is only known at run time, so the compile-time output-dim inference above is removed; the kernel sets output shape and height itself (see the next file, and the sketch that follows it).

paddle/fluid/operators/split_selected_rows_op.h

Lines changed: 1 addition & 0 deletions

@@ -55,6 +55,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
 
     for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
       auto rows_idx = outs_rows_idx[i];
+      outs[i]->set_height(height_sections[i]);
       if (rows_idx.size() > 0) {
         auto dims = x->GetCompleteDims();
         dims[0] = rows_idx.size();
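Together with the removal of compile-time shape inference in split_selected_rows_op.cc, this moves all shape knowledge to run time: rows are bucketed into height sections as they arrive, and every output gets its height set even when it receives no rows. A standalone sketch of that routing, using a toy SelectedRows stand-in rather than the real kernel types:

#include <iostream>
#include <vector>

// Toy stand-in for paddle::framework::SelectedRows (simplified assumption).
struct SelectedRows {
  long height = 0;         // logical row count of the dense view
  std::vector<long> rows;  // indices of the rows actually present
  void set_height(long h) { height = h; }
};

int main() {
  // Input: rows 1, 5 and 8 of a height-10 variable are present.
  SelectedRows x;
  x.height = 10;
  x.rows = {1, 5, 8};

  // Section 0 owns global rows [0, 4); section 1 owns [4, 10).
  std::vector<long> height_sections = {4, 6};

  // Bucket each input row into the section that owns it. Nothing about
  // these counts is known when the graph is compiled.
  std::vector<std::vector<long>> outs_rows_idx(height_sections.size());
  long start = 0;
  for (size_t i = 0; i < height_sections.size(); ++i) {
    long end = start + height_sections[i];
    for (long row : x.rows) {
      if (row >= start && row < end) outs_rows_idx[i].push_back(row - start);
    }
    start = end;
  }

  std::vector<SelectedRows> outs(height_sections.size());
  for (size_t i = 0; i < outs.size(); ++i) {
    // Height is set unconditionally (the one-line fix in
    // split_selected_rows_op.h), so even empty outputs are well formed.
    outs[i].set_height(height_sections[i]);
    outs[i].rows = outs_rows_idx[i];
    std::cout << "out" << i << ": height=" << outs[i].height
              << " rows=" << outs[i].rows.size() << "\n";
  }
  return 0;
}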

paddle/fluid/operators/sum_op.h

Lines changed: 3 additions & 1 deletion

@@ -116,7 +116,9 @@ class SumKernel : public framework::OpKernel<T> {
     int64_t offset = 0;
     for (int i = 0; i < N; i++) {
       auto &sel_row = get_selected_row(i);
-
+      if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) {
+        continue;
+      }
       PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
       functor(context.template device_context<DeviceContext>(), sel_row,
               offset, out);
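This is the consumer side of the reset done in listen_and_serv_op.cc: a SelectedRows input whose rows were cleared, or whose value tensor was never initialized, is skipped instead of failing the height check. A toy sketch of the guarded merge loop, with simplified stand-in types:

#include <iostream>
#include <vector>

// Toy stand-in for paddle::framework::SelectedRows (simplified assumption).
struct SelectedRows {
  long height;
  std::vector<long> rows;
  std::vector<float> value;
  bool IsInitialized() const { return !value.empty(); }
};

int main() {
  SelectedRows a{10, {0, 2}, {1.0f, 2.0f}};
  SelectedRows b{10, {}, {}};  // reset after the last mini-batch: no rows
  SelectedRows out{10, {}, {}};

  for (const SelectedRows& sel_row : {a, b}) {
    // Mirror the added guard: skip inputs that are uninitialized or empty
    // instead of letting them trip the height check.
    if (!sel_row.IsInitialized() || sel_row.rows.empty()) continue;
    // (the real kernel enforces matching heights, then merges row data)
    out.rows.insert(out.rows.end(), sel_row.rows.begin(), sel_row.rows.end());
    out.value.insert(out.value.end(), sel_row.value.begin(),
                     sel_row.value.end());
  }
  std::cout << "merged rows: " << out.rows.size() << "\n";  // prints 2
  return 0;
}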

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 4 additions & 0 deletions

@@ -191,6 +191,7 @@ def transpile(self,
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
+
         # let send_op know which endpoint to send which var to; eplist has the
         # same order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
@@ -274,6 +275,7 @@ def _create_vars_from_blocklist(self, program, block_list):
                 name="%s.block%d" % (varname, i),
                 persistable=False,
                 dtype=orig_var.dtype,
+                type=orig_var.type,
                 shape=splited_shape)  # flattened split var
             var_mapping[varname].append(var)
         return var_mapping
@@ -335,6 +337,7 @@ def _create_var_for_trainers(self, block, var, trainers):
                 name="%s.trainer_%d" % (var.name, i),
                 persistable=var.persistable,
                 dtype=var.dtype,
+                type=var.type,
                 shape=var.shape)
             var_list.append(var_each)
         return var_list
@@ -561,6 +564,7 @@ def get_pserver_program(self, endpoint):
             persistable=True,
             dtype=v.dtype,
             shape=v.shape)
+
         # step6
         optimize_block = pserver_program.create_block(0)
         # step 6.1

Note: newly created variables default to LOD_TENSOR, so propagating type=orig_var.type (and type=var.type) keeps split and per-trainer copies of a SelectedRows gradient sparse instead of silently turning them dense.
