
Commit 416922e

Distributed training cherry-pick for Release 1.5 (#19486)
* fix bug in Class MultiSlotDataGenerator's function _gen_str, test=develop (#18222)
* fix some bug when merge sparse embedding parameters, test=develop (#18223)
* fix communicator with pyreader (#18350)
* delete AllocatorFacade destructor (#18606)
* fix distribute transpiler GRPC error code 4, RPC Deadline (#18984)
* merge pr #18441
1 parent 0edeb83 commit 416922e

32 files changed: +631 −343 lines

paddle/fluid/framework/details/async_ssa_graph_executor.cc

Lines changed: 8 additions & 3 deletions
```diff
@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
   // init communicator here
   if (send_varname_to_ctx.size() > 0) {
     VLOG(3) << "this is distribute mode, will use communicator";
-    operators::distributed::Communicator::Init(send_varname_to_ctx,
-                                               recv_varname_to_ctx, scope);
-    operators::distributed::Communicator::GetInstance()->Start();
+
+    if (operators::distributed::Communicator::GetInstance() == nullptr) {
+      operators::distributed::Communicator::Init(send_varname_to_ctx,
+                                                 recv_varname_to_ctx, scope);
+      operators::distributed::Communicator::GetInstance()->Start();
+    } else {
+      VLOG(3) << "communicator has been initialized, skip";
+    }
   }
 #endif
 }
```
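For context, the guard added above is a standard init-once pattern for a process-wide singleton: initialize and start only if no instance exists yet, so a second pass over the graph cannot re-create the communicator. Below is a minimal, self-contained C++ sketch of the same idea; this `Communicator` is a hypothetical stand-in written for illustration, not Paddle's class.

```cpp
#include <iostream>
#include <memory>
#include <mutex>

// Hypothetical stand-in for the real Communicator: Init() creates the
// process-wide instance exactly once; later calls find it already set.
class Communicator {
 public:
  static Communicator* GetInstance() { return instance_.get(); }

  static void Init() {
    std::call_once(init_flag_, [] { instance_.reset(new Communicator()); });
  }

  void Start() { std::cout << "communicator started\n"; }

 private:
  Communicator() = default;
  static std::once_flag init_flag_;
  static std::unique_ptr<Communicator> instance_;
};

std::once_flag Communicator::init_flag_;
std::unique_ptr<Communicator> Communicator::instance_;

int main() {
  // Mirrors the patched ProcessGraph(): init and start only when no
  // instance exists; a repeated pass takes the "skip" branch instead.
  if (Communicator::GetInstance() == nullptr) {
    Communicator::Init();
    Communicator::GetInstance()->Start();
  } else {
    std::cout << "communicator has been initialized, skip\n";
  }
}
```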

paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h

Lines changed: 0 additions & 7 deletions
```diff
@@ -133,13 +133,6 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
       VLOG(1) << "set recv op do_not_run to true";
       node->Op()->SetAttr("do_not_run", 1);
       node->Op()->Flush();
-    } else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
-               node->Name() == "hierarchical_sigmoid") {
-      // in async_mode, we do not need remote prefetch, because communicator
-      // will do async parameter recv.
-      VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
-      node->Op()->SetAttr("remote_prefetch", false);
-      node->Op()->Flush();
     }
     return false;
   }
```
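The deleted branch had force-disabled `remote_prefetch` on embedding-style ops (`lookup_table`, `nce`, `hierarchical_sigmoid`) in async mode; this commit drops that graph rewrite. For readers unfamiliar with such passes, here is a simplified sketch of the match-by-name, rewrite-an-attribute pattern the removed code used; `FakeOpNode` is a hypothetical miniature, not Paddle's `ir::Node` / `OpDesc` API.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical miniature of an IR node: just a name plus boolean attributes.
struct FakeOpNode {
  std::string name;
  std::map<std::string, bool> bool_attrs;

  void SetAttr(const std::string& key, bool value) { bool_attrs[key] = value; }
};

// The removed pass branch, in miniature: match ops by name while walking
// the graph and rewrite one of their attributes in place.
void ApplyAsyncModeRule(FakeOpNode* node) {
  if (node->name == "lookup_table" || node->name == "nce" ||
      node->name == "hierarchical_sigmoid") {
    std::cout << "set " << node->name << " op remote_prefetch to false\n";
    node->SetAttr("remote_prefetch", false);
  }
}

int main() {
  FakeOpNode node{"lookup_table", {}};
  ApplyAsyncModeRule(&node);
}
```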

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -248,6 +248,8 @@ class ExecutionContext {
     return op_.Attr<T>(name);
   }
 
+  bool HasAttr(const std::string& name) const { return op_.HasAttr(name); }
+
   bool HasInput(const std::string& name) const;
 
   bool HasOutput(const std::string& name) const;
```
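The new `HasAttr` lets kernel code probe for an attribute before reading it, rather than letting `Attr<T>` fail on a missing key. A minimal sketch of that probe-then-read pattern, with a hypothetical `AttrReader` standing in for `ExecutionContext`:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for ExecutionContext's attribute interface.
class AttrReader {
 public:
  explicit AttrReader(std::map<std::string, bool> attrs)
      : attrs_(std::move(attrs)) {}

  bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }

  // Like ExecutionContext::Attr<T>, this assumes the attribute exists.
  bool Attr(const std::string& name) const { return attrs_.at(name); }

 private:
  std::map<std::string, bool> attrs_;
};

int main() {
  AttrReader ctx({{"do_not_run", true}});

  // Probe first, then read: ops whose program descs lack the attribute
  // fall back to a default instead of throwing on the missing key.
  bool remote_prefetch =
      ctx.HasAttr("remote_prefetch") ? ctx.Attr("remote_prefetch") : false;
  std::cout << "remote_prefetch = " << std::boolalpha << remote_prefetch
            << "\n";
}
```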

paddle/fluid/memory/allocation/allocator_facade.cc

Lines changed: 3 additions & 1 deletion
```diff
@@ -295,7 +295,9 @@ class AllocatorFacadePrivate {
 
 // Pimpl. Make interface clean.
 AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {}
-AllocatorFacade::~AllocatorFacade() { delete m_; }
+// delete m_ may cause core dump when the destructor of python in conflict with
+// cpp.
+AllocatorFacade::~AllocatorFacade() {}
 
 AllocatorFacade& AllocatorFacade::Instance() {
   static AllocatorFacade instance;
```
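The emptied destructor deliberately leaks `m_`: the facade lives in a function-local static, whose destructor runs during process teardown, where (per the in-diff comment) C++ static destruction can collide with the embedding Python interpreter's own shutdown. A standalone sketch of this leak-on-exit pattern, with hypothetical names:

```cpp
#include <iostream>

// Hypothetical pimpl holding resources whose destruction order at process
// exit is unsafe (e.g., they may be torn down after the host interpreter).
class FacadePrivate {
 public:
  FacadePrivate() { std::cout << "private impl constructed\n"; }
  ~FacadePrivate() { std::cout << "private impl destroyed\n"; }
};

class Facade {
 public:
  static Facade& Instance() {
    static Facade instance;  // destroyed at exit, in static-teardown order
    return instance;
  }

  // Intentionally do NOT delete m_: letting the OS reclaim the memory at
  // process exit avoids running destructors during ambiguous shutdown.
  ~Facade() {}

 private:
  Facade() : m_(new FacadePrivate()) {}
  FacadePrivate* m_;
};

int main() {
  Facade::Instance();  // "private impl destroyed" is never printed
}
```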

paddle/fluid/operators/distributed/communicator.cc

Lines changed: 37 additions & 15 deletions
```diff
@@ -73,14 +73,26 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
   VLOG(0) << "communicator_max_merge_var_num: "
           << FLAGS_communicator_max_merge_var_num;
   VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
-  send_scope_.reset(new Scope());
-  for (auto &iter : send_varname_to_ctx_) {
-    send_varname_to_queue_[iter.first] =
-        std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
-            FLAGS_communicator_send_queue_size);
+
+  if (send_varname_to_ctx.size() == 0) {
+    VLOG(0) << "nothing need to be send, will not start send_thread";
+  } else {
+    send_scope_.reset(new Scope());
+    for (auto &iter : send_varname_to_ctx_) {
+      send_varname_to_queue_[iter.first] =
+          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
+              FLAGS_communicator_send_queue_size);
+    }
+    send_threadpool_.reset(
+        new ::ThreadPool(FLAGS_communicator_thread_pool_size));
+  }
+
+  if (recv_varname_to_ctx.size() == 0) {
+    VLOG(0) << "nothing need to be received, will not start recv_thread";
+  } else {
+    recv_threadpool_.reset(
+        new ::ThreadPool(FLAGS_communicator_thread_pool_size));
   }
-  send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
-  recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
 }
 
 Communicator::~Communicator() {
@@ -157,18 +169,28 @@ void Communicator::SendThread() {
       task_f.wait();
     }
     auto after_run_send_graph = GetCurrentUS();
-    auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
-    if (send_graph_use_time > 100) {
-      VLOG(1) << "run send graph use time "
-              << after_run_send_graph - before_run_send_graph;
-    }
-    if (!FLAGS_communicator_independent_recv_thread) {
-      RecvAll();
-    }
+
+    VLOG(3) << "run send graph use time "
+            << after_run_send_graph - before_run_send_graph;
+    RecvNonIndependent();
   }
   VLOG(0) << "communicator stopped, send thread exit";
 }
 
+void Communicator::RecvNonIndependent() {
+  if (!FLAGS_communicator_independent_recv_thread) {
+    return;
+  }
+
+  auto grad_num = grad_num_.load();
+  if (grad_num > 0) {
+    RecvAll();
+    grad_num_.store(0);
+  } else {
+    std::this_thread::sleep_for(std::chrono::milliseconds(10));
+  }
+}
+
 void Communicator::RecvAll() {
   VLOG(3) << "parallel run recv graph";
   if (!running_) return;
```
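The counter-driven half of `RecvNonIndependent` — receive parameters only when gradients were actually sent since the last recv, and back off otherwise — can be sketched outside Paddle. The `grad_num`/`RecvAll`/10 ms-sleep logic below mirrors the diff; the surrounding loop, globals, and types are illustrative only.

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

std::atomic<int> grad_num{0};    // incremented elsewhere when a grad is sent
std::atomic<bool> running{true};

void RecvAll() { std::cout << "recv all parameters\n"; }

// Illustrative loop body: pull parameters only if work arrived since the
// last recv; otherwise sleep briefly instead of busy-spinning.
void RecvNonIndependent() {
  int n = grad_num.load();
  if (n > 0) {
    RecvAll();
    grad_num.store(0);
  } else {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}

int main() {
  grad_num.store(3);  // pretend three gradients were sent
  while (running.load()) {
    RecvNonIndependent();
    if (grad_num.load() == 0) running.store(false);  // stop after one pass
  }
}
```

Note that the separate `load()` and `store(0)` are not one atomic step; if a gradient arriving between the two calls had to be counted, a single `grad_num.exchange(0)` would close that window.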

paddle/fluid/operators/distributed/communicator.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -167,12 +167,15 @@ class Communicator {
   void Start();
   void Stop();
 
+  bool IsRunning() { return running_; }
+
   // send grad
   void Send(const std::string& var_name, const framework::Scope& scope);
 
  private:
   // recv all parameter
   void RecvAll();
+  void RecvNonIndependent();
   void SendThread();
   void RecvThread();
```
