
Commit 781d284

Optimize decay (#20816) (#20952)
* update pserver decay blocks
* update distributed notify handler
1 parent 55c2329 commit 781d284

17 files changed: +399 / -194 lines

paddle/fluid/framework/details/async_ssa_graph_executor.cc

Lines changed: 9 additions & 1 deletion
@@ -62,8 +62,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
           node->Op()->GetNullableAttr("sections"));
       auto trainer_id =
           boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
+      auto merge_add =
+          boost::get<bool>(node->Op()->GetNullableAttr("merge_add"));
+      if (!merge_add) {
+        merge_add = FLAGS_communicator_is_sgd_optimizer;
+      }
+      auto use_send_handler =
+          boost::get<bool>(node->Op()->GetNullableAttr("use_send_handler"));
       send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
-          send_var_name, send_varnames, epmap, height_section, trainer_id);
+          send_var_name, send_varnames, epmap, height_section, trainer_id,
+          merge_add, use_send_handler);
       VLOG(3) << "find and init an send op: "
               << send_varname_to_ctx[send_var_name];
     } else if (node->Name() == "recv") {
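The two attributes read above, "merge_add" and "use_send_handler", are carried into the RpcContext that the communicator and ParameterSend consult later in this commit. The context definition itself (rpc_common.h) is not part of the excerpt shown here, so the following is only an assumed sketch of the fields the new constructor arguments imply; the member names follow the call site above and the ctx.merge_add / ctx.use_send_handler reads in communicator.cc below.

// Hypothetical sketch of the extended send context -- not the actual
// rpc_common.h definition, which may differ in member names and order.
struct RpcContext {
  std::string var_name;                        // variable registered for sending
  std::vector<std::string> splited_var_names;  // per-pserver slices of var_name
  std::vector<std::string> epmap;              // pserver endpoints
  std::vector<int64_t> height_sections;        // row count of each slice
  int trainer_id = 0;
  bool merge_add = true;         // true: sum merged updates, false: average them
  bool use_send_handler = true;  // true: regular SendVar RPC, false: DistributeNotify
};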

paddle/fluid/operators/distributed/communicator.cc

Lines changed: 15 additions & 3 deletions
@@ -130,8 +130,15 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
       auto height_section =
           boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
       auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
+      auto merge_add = boost::get<bool>(op->GetNullableAttr("merge_add"));
+      if (!merge_add) {
+        merge_add = FLAGS_communicator_is_sgd_optimizer;
+      }
+      auto use_send_handler =
+          boost::get<bool>(op->GetNullableAttr("use_send_handler"));
       send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
-          send_var_name, send_varnames, epmap, height_section, trainer_id);
+          send_var_name, send_varnames, epmap, height_section, trainer_id,
+          merge_add, use_send_handler);
       VLOG(3) << "find and init an send op: "
               << send_varname_to_ctx[send_var_name];
     } else if (op->Type() == "recv") {
@@ -208,12 +215,17 @@ void AsyncCommunicator::SendThread() {
             }
           }
           auto before_merge = GetCurrentUS();
-          MergeVars(var_name, vars, send_scope_.get());
+          auto &ctx = send_varname_to_ctx_.at(var_name);
+          if (ctx.use_send_handler) {
+            MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
+          } else {
+            MergeVars<int64_t>(var_name, vars, send_scope_.get(),
+                               ctx.merge_add);
+          }
           auto after_merge = GetCurrentUS();
           VLOG(3) << "merge " << merged_var_num << " " << var_name
                   << " use time " << after_merge - before_merge;
           auto send_functor = distributed::ParameterSend<float>();
-          auto &ctx = send_varname_to_ctx_.at(var_name);
           if (!FLAGS_communicator_fake_rpc) {
             send_functor(ctx, *send_scope_, true, 1);
           }
paddle/fluid/operators/distributed/communicator.h

Lines changed: 16 additions & 19 deletions
@@ -107,21 +107,21 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
+template <typename T>
 inline void MergeVars(const std::string& var_name,
                       const std::vector<std::shared_ptr<Variable>>& vars,
-                      Scope* scope) {
+                      Scope* scope, bool merge_add = true) {
   PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
   auto cpu_place = platform::CPUPlace();
   auto& var0 = vars[0];
   auto* out_var = scope->Var(var_name);
   if (var0->IsType<framework::LoDTensor>()) {
     auto dims = var0->Get<framework::LoDTensor>().dims();
-    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims;
-
+    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
+            << "; merge add: " << merge_add;
     // init output tensor
     auto* out_t = out_var->GetMutable<framework::LoDTensor>();
-    out_t->mutable_data<float>(dims, cpu_place);
-
+    out_t->mutable_data<T>(dims, cpu_place);
     // check the input dims
     for (auto& var : vars) {
       auto& var_t = var->Get<framework::LoDTensor>();
@@ -130,44 +130,41 @@ inline void MergeVars(const std::string& var_name,
 
     // set output tensor to 0.
     auto cpu_ctx = paddle::platform::CPUDeviceContext();
-    math::SetConstant<paddle::platform::CPUDeviceContext, float>
-        constant_functor;
-    constant_functor(cpu_ctx, out_t, static_cast<float>(0));
-
+    math::SetConstant<paddle::platform::CPUDeviceContext, T> constant_functor;
+    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
     // sum all vars to out
-    auto result = EigenVector<float>::Flatten(*out_t);
+    auto result = EigenVector<T>::Flatten(*out_t);
     for (auto& var : vars) {
       auto& in_t = var->Get<framework::LoDTensor>();
-      auto in = EigenVector<float>::Flatten(in_t);
+      auto in = EigenVector<T>::Flatten(in_t);
       result.device(*cpu_ctx.eigen_device()) = result + in;
     }
-    if (!FLAGS_communicator_is_sgd_optimizer) {
+    if (!merge_add) {
       result.device(*cpu_ctx.eigen_device()) =
-          result / static_cast<float>(vars.size());
+          result / static_cast<T>(vars.size());
     }
   } else if (var0->IsType<framework::SelectedRows>()) {
     auto& slr0 = var0->Get<framework::SelectedRows>();
     auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
     out_slr->mutable_rows()->clear();
-    out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
+    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
     std::vector<const paddle::framework::SelectedRows*> inputs;
     inputs.reserve(vars.size());
     for (auto& var : vars) {
       inputs.push_back(&var->Get<framework::SelectedRows>());
     }
     auto dev_ctx = paddle::platform::CPUDeviceContext();
-    if (FLAGS_communicator_is_sgd_optimizer) {
-      math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
-          merge_add;
+    if (merge_add) {
+      math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, T> merge_add;
       merge_add(dev_ctx, inputs, out_slr);
     } else {
-      math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
+      math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, T>
          merge_average;
      merge_average(dev_ctx, inputs, out_slr);
    }
 
    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
-            << " dims: " << slr0.value().dims();
+            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW("unsupported var type!");
  }
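
For the dense (LoDTensor) branch, merge_add only decides whether the element-wise sum over the queued inputs is divided by their count. A small self-contained sketch of that arithmetic on plain vectors (no Paddle types), assuming all inputs have equal length:

#include <cassert>
#include <cstddef>
#include <vector>

// Sum (merge_add = true) or average (merge_add = false) a batch of inputs,
// matching what MergeVars<T> does in the LoDTensor branch above.
template <typename T>
std::vector<T> MergeDense(const std::vector<std::vector<T>>& inputs,
                          bool merge_add) {
  assert(!inputs.empty());
  std::vector<T> out(inputs[0].size(), static_cast<T>(0));
  for (const auto& in : inputs) {
    for (std::size_t i = 0; i < in.size(); ++i) out[i] += in[i];
  }
  if (!merge_add) {
    for (auto& v : out) v /= static_cast<T>(inputs.size());
  }
  return out;
}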

paddle/fluid/operators/distributed/communicator_test.cc

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ TEST(communicator, merge_lod_tensors) {
   scope.reset(new framework::Scope());
   scope->Var(out_name);
   for (auto i = 0; i < 10; ++i) {
-    MergeVars(out_name, in_vars, scope.get());
+    MergeVars<float>(out_name, in_vars, scope.get());
   }
   auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
   auto *out_data = out_tensor.data<float>();
@@ -86,7 +86,7 @@ TEST(communicator, merge_selected_rows) {
   scope.reset(new framework::Scope());
   scope->Var(out_name);
   for (auto i = 0; i < 10; ++i) {
-    MergeVars(out_name, in_vars, scope.get());
+    MergeVars<float>(out_name, in_vars, scope.get());
   }
   auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
   auto &out_t = out_slr.value();

paddle/fluid/operators/distributed/grpc/grpc_client.cc

Lines changed: 28 additions & 14 deletions
@@ -438,26 +438,40 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   return h;
 }
 
-VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep,
-                                               const std::string& type,
-                                               int64_t time_out) {
-  const auto ch = GetChannel(ep);
-
-  DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch);
-
+VarHandlePtr GRPCClient::AsyncDistributeNotify(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& var_name,
+    int64_t time_out) {
+  const platform::DeviceContext* p_ctx = &ctx;
+  const std::string ep_val = ep;
+  const std::string var_name_val = var_name;
+  const framework::Scope* p_scope = &scope;
+  const auto ch = GetChannel(ep_val);
   const std::string method = kRequestNotify;
 
-  VarHandlePtr h(
-      new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr));
+  SendProcessor* s = new SendProcessor(ch);
+  VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  sendrecv::VariableMessage req;
-  req.set_varname(type);
+  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+    auto* var = p_scope->FindVar(var_name_val);
 
-  platform::RecordRPCEvent record_event(method);
+    ::grpc::ByteBuffer req;
+    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
 
-  auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_);
-  rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
+    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
+
+    // stub context
+    s->response_call_back_ = nullptr;
+
+    platform::RecordRPCEvent record_event(method);
+
+    auto call = s->stub_g_.PrepareUnaryCall(
+        s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req,
+        &cq_);
+    call->StartCall();
+    call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
+  });
   req_count_++;
 
   if (UNLIKELY(platform::IsProfileEnabled())) {

paddle/fluid/operators/distributed/grpc/grpc_client.h

Lines changed: 2 additions & 15 deletions
@@ -173,20 +173,6 @@ class CheckpointNotifyProcessor : public BaseProcessor {
   std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
 };
 
-class DistributeNotifyProcessor : public BaseProcessor {
- public:
-  explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
-      : BaseProcessor() {
-    stub_ = sendrecv::SendRecvService::NewStub(ch);
-  }
-
-  virtual ~DistributeNotifyProcessor() {}
-
-  void ProcessImpl() override {}
-  sendrecv::VoidMessage reply_;
-  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
-};
-
 class GRPCClient : public RPCClient {
  public:
   GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
@@ -240,7 +226,8 @@ class GRPCClient : public RPCClient {
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncDistributeNotify(
-      const std::string& ep, const std::string& type,
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncSendComplete(

paddle/fluid/operators/distributed/grpc/grpc_server.cc

Lines changed: 9 additions & 11 deletions
@@ -400,33 +400,31 @@ class RequestNotify final : public RequestBase {
                 RequestHandler* request_handler, int req_id)
       : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
     request_.reset(new GRPCVariableResponse(request_handler->scope(),
-                                            request_handler->dev_ctx()));
+                                            request_handler->dev_ctx(),
+                                            !request_handler->sync_mode()));
     int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
     service_->RequestAsyncUnary(
         method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
         reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
   }
-
   virtual ~RequestNotify() {}
-
   std::string GetReqName() override { return request_->Varname(); }
 
   void Process() override {
-    auto scope = request_->GetMutableLocalScope();
+    std::string varname = GetReqName();
+    VLOG(4) << "RequestNotify var_name:" << varname;
 
-    std::string varname = request_->Varname();
+    auto scope = request_->GetMutableLocalScope();
+    auto invar = request_->GetVar();
     int trainer_id = request_->GetTrainerId();
-
-    VLOG(4) << "RequestNotify notify: " << varname
-            << ", trainer id: " << trainer_id;
-
-    request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id);
+    framework::Variable* outvar = nullptr;
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
     Finish(reply_, &responder_);
   }
 
  protected:
-  std::shared_ptr<GRPCVariableResponse> request_;
   sendrecv::VoidMessage reply_;
+  std::shared_ptr<GRPCVariableResponse> request_;
   ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
 };

paddle/fluid/operators/distributed/parameter_send.cc

Lines changed: 36 additions & 16 deletions
@@ -116,24 +116,44 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
         row_offset += outs_dims[i][0];
       }
     }
-
-    for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
-      auto &send_var_name = rpc_ctx.splited_var_names[i];
-      VLOG(4) << "send var name: " << send_var_name;
-      auto &endpoint = rpc_ctx.epmap[i];
-      VLOG(4) << "send var endpoint: " << endpoint;
-      VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
-      if (NeedSend(*local_scope.get(), send_var_name)) {
-        VLOG(3) << "sending " << send_var_name << " to " << endpoint;
-        rets.push_back(rpc_client->AsyncSendVar(
-            endpoint, cpu_ctx, *local_scope.get(), send_var_name));
-        VLOG(4) << "send var " << send_var_name << " async handle done";
-      } else {
-        VLOG(3) << "don't send non-initialized variable: "
-                << rpc_ctx.splited_var_names[i];
+    if (rpc_ctx.use_send_handler) {
+      for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
+        auto &send_var_name = rpc_ctx.splited_var_names[i];
+        VLOG(4) << "send var name: " << send_var_name;
+        auto &endpoint = rpc_ctx.epmap[i];
+        VLOG(4) << "send var endpoint: " << endpoint;
+        VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
+        if (NeedSend(*local_scope.get(), send_var_name)) {
+          VLOG(3) << "sending " << send_var_name << " to " << endpoint;
+          rets.push_back(rpc_client->AsyncSendVar(
+              endpoint, cpu_ctx, *local_scope.get(), send_var_name));
+          VLOG(4) << "send var " << send_var_name << " async handle done";
+        } else {
+          VLOG(3) << "don't send non-initialized variable: "
+                  << rpc_ctx.splited_var_names[i];
+        }
+      }
+    } else {
+      for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
+        for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
+          auto &send_var_name = rpc_ctx.splited_var_names[i];
+          VLOG(4) << "send var name: " << send_var_name;
+          auto &endpoint = rpc_ctx.epmap[j];
+          VLOG(4) << "send var endpoint: " << endpoint;
+          VLOG(4) << "need send: "
+                  << NeedSend(*local_scope.get(), send_var_name);
+          if (NeedSend(*local_scope.get(), send_var_name)) {
+            VLOG(3) << "sending " << send_var_name << " to " << endpoint;
+            rets.push_back(rpc_client->AsyncDistributeNotify(
+                endpoint, cpu_ctx, *local_scope.get(), send_var_name));
+            VLOG(4) << "send var " << send_var_name << " async handle done";
+          } else {
+            VLOG(3) << "don't send non-initialized variable: "
+                    << rpc_ctx.splited_var_names[i];
+          }
+        }
       }
     }
-
   } else if (send_var->IsType<framework::SelectedRows>()) {
     auto &send_slr = send_var->Get<framework::SelectedRows>();
     auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
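
The new else branch changes the fan-out pattern: with the regular send handler, slice i of a split variable goes only to its matching endpoint epmap[i], while a notify-style variable (the decay counter, which is not split) is sent to every endpoint so each pserver can update its own copy. A stripped-down sketch of the two index patterns follows; SendSlice and NotifyCounter are placeholders for rpc_client->AsyncSendVar and rpc_client->AsyncDistributeNotify.

#include <cstddef>
#include <string>
#include <vector>

// Placeholders for rpc_client->AsyncSendVar / rpc_client->AsyncDistributeNotify.
void SendSlice(const std::string& var, const std::string& ep);
void NotifyCounter(const std::string& var, const std::string& ep);

void FanOut(const std::vector<std::string>& splited_var_names,
            const std::vector<std::string>& epmap, bool use_send_handler) {
  if (use_send_handler) {
    // regular send: slice i goes to its matching pserver endpoint i
    for (std::size_t i = 0; i < splited_var_names.size(); ++i) {
      SendSlice(splited_var_names[i], epmap[i]);
    }
  } else {
    // notify send: every variable is broadcast to every pserver endpoint
    for (std::size_t i = 0; i < splited_var_names.size(); ++i) {
      for (std::size_t j = 0; j < epmap.size(); ++j) {
        NotifyCounter(splited_var_names[i], epmap[j]);
      }
    }
  }
}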

paddle/fluid/operators/distributed/request_handler.h

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
 #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
 #define COMPLETE_MESSAGE "COMPLETE@RECV"
 #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
-#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"
+#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
 
 #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
 #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 16 additions & 2 deletions
@@ -262,11 +262,25 @@ bool RequestNotifyHandler::Handle(const std::string& varname,
                                   const int trainer_id,
                                   const std::string& out_var_name,
                                   const std::string& table_name) {
-  VLOG(4) << "RequestNotifyHandler" << varname;
-  if (varname == LEARNING_RATE_DECAY_MESSAGE) {
+  VLOG(4) << "RequestNotifyHandler: " << varname;
+  VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
+
+  string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
+  string::Piece var_name_piece = string::Piece(varname);
+  if (string::Contains(var_name_piece, decay_piece)) {
+    VLOG(3) << "LearningRate Decay Counter Update";
     PADDLE_ENFORCE_NE(
         lr_decay_block_id, -1,
         "when lr_decay_block_id = -1, there should be no RPC invoke.");
+    auto* origin_var = scope_->FindVar(varname);
+    auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
+    auto* send_var = scope->FindVar(varname);
+    auto send_var_tensor = send_var->Get<framework::LoDTensor>();
+    int64_t* origin_value =
+        origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
+    int64_t* send_value =
+        send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
+    origin_value[0] += send_value[0];
     executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
   }
   return true;