Skip to content

Commit 88d79df

Browse files
authored
Merge pull request #10292 from typhoonzero/fix_grpc_server_ready_condition
Fix grpc server ready condition
2 parents 6084af4 + 6422c0e commit 88d79df

File tree

6 files changed

+66
-22
lines changed

6 files changed

+66
-22
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ void AsyncGRPCServer::WaitClientGet(int count) {
211211
}
212212
}
213213

214+
void AsyncGRPCServer::WaitServerReady() {
215+
std::unique_lock<std::mutex> lock(this->mutex_ready_);
216+
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
217+
}
218+
214219
void AsyncGRPCServer::RunSyncUpdate() {
215220
::grpc::ServerBuilder builder;
216221
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
@@ -244,6 +249,12 @@ void AsyncGRPCServer::RunSyncUpdate() {
244249
t_prefetch_.reset(new std::thread(
245250
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
246251
"cq_prefetch", prefetch_register)));
252+
253+
{
254+
std::lock_guard<std::mutex> lock(this->mutex_ready_);
255+
ready_ = 1;
256+
}
257+
condition_ready_.notify_all();
247258
// wait server
248259
server_->Wait();
249260
t_send_->join();

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ class RequestBase;
4545
class AsyncGRPCServer final {
4646
public:
4747
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
48-
: address_(address), sync_mode_(sync_mode) {}
48+
: address_(address), sync_mode_(sync_mode), ready_(0) {}
4949

50+
void WaitServerReady();
5051
void RunSyncUpdate();
5152

5253
// functions to sync server barrier status.
@@ -118,6 +119,10 @@ class AsyncGRPCServer final {
118119
framework::ProgramDesc *program_;
119120
framework::Executor *executor_;
120121
int selected_port_;
122+
123+
std::mutex mutex_ready_;
124+
std::condition_variable condition_ready_;
125+
int ready_;
121126
};
122127

123128
}; // namespace detail

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,35 @@ static void ParallelExecuteBlocks(
6666
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
6767
}
6868

69-
static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
70-
std::ofstream port_file;
71-
port_file.open("/tmp/paddle.selected_port");
72-
port_file << rpc_service->GetSelectedPort();
73-
port_file.close();
74-
}
69+
std::atomic_int ListenAndServOp::selected_port_{0};
7570

7671
ListenAndServOp::ListenAndServOp(const std::string &type,
7772
const framework::VariableNameMap &inputs,
7873
const framework::VariableNameMap &outputs,
7974
const framework::AttributeMap &attrs)
8075
: OperatorBase(type, inputs, outputs, attrs) {}
8176

82-
int ListenAndServOp::GetSelectedPort() const {
83-
return rpc_service_->GetSelectedPort();
84-
}
85-
8677
void ListenAndServOp::Stop() {
8778
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
8879
server_thread_->join();
8980
}
9081

82+
void ListenAndServOp::SavePort(const std::string &file_path) const {
83+
// NOTE: default write file to /tmp/paddle.selected_port
84+
selected_port_ = rpc_service_->GetSelectedPort();
85+
86+
std::ofstream port_file;
87+
port_file.open(file_path);
88+
port_file << selected_port_.load();
89+
port_file.close();
90+
VLOG(4) << "selected port written to " << file_path;
91+
}
92+
93+
void ListenAndServOp::WaitServerReady() {
94+
while (selected_port_.load() == 0) {
95+
}
96+
}
97+
9198
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
9299
framework::ProgramDesc *program,
93100
framework::Scope *recv_scope,
@@ -318,9 +325,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
318325
// start the server listening after all member initialized.
319326
server_thread_.reset(new std::thread(RunServer, rpc_service_));
320327
VLOG(3) << "wait server thread to become ready...";
321-
sleep(5);
328+
rpc_service_->WaitServerReady();
329+
322330
// Write to a file of server selected port for python use.
323-
SavePort(rpc_service_);
331+
std::string file_path =
332+
string::Sprintf("/tmp/paddle.%d.selected_port",
333+
static_cast<int>(::getpid()));
334+
SavePort(file_path);
324335
if (sync_mode) {
325336
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
326337
} else {

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <stdint.h>
18+
#include <atomic>
1819
#include <ostream>
1920
#include <string>
2021

@@ -39,8 +40,6 @@ class ListenAndServOp : public framework::OperatorBase {
3940
const framework::VariableNameMap& outputs,
4041
const framework::AttributeMap& attrs);
4142

42-
int GetSelectedPort() const;
43-
4443
void RunSyncLoop(framework::Executor* executor,
4544
framework::ProgramDesc* program,
4645
framework::Scope* recv_scope,
@@ -49,14 +48,25 @@ class ListenAndServOp : public framework::OperatorBase {
4948
void RunAsyncLoop(framework::Executor* executor,
5049
framework::ProgramDesc* program) const;
5150

51+
void SavePort(
52+
const std::string& file_path = "/tmp/paddle.selected_port") const;
53+
54+
void WaitServerReady();
55+
56+
int GetSelectedPort() { return selected_port_; }
57+
5258
void Stop() override;
5359

5460
void RunImpl(const framework::Scope& scope,
5561
const platform::Place& dev_place) const override;
5662

63+
static void ResetPort() { selected_port_ = 0; }
64+
5765
protected:
5866
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
5967
mutable std::shared_ptr<std::thread> server_thread_;
68+
// FIXME(wuyi): it's static so that the operator can be cloned.
69+
static std::atomic_int selected_port_;
6070
};
6171

6272
} // namespace operators

paddle/fluid/operators/send_recv_op_test.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
116116
void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
117117
f::Scope scope;
118118
p::CPUPlace place;
119+
VLOG(4) << "before init tensor";
119120
if (is_sparse) {
120121
InitSelectedRowsInScope(place, &scope);
121122
} else {
@@ -137,6 +138,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
137138
attrs.insert({"PrefetchBlock", prefetch_block});
138139
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
139140
attrs.insert({"sync_mode", true});
141+
VLOG(4) << "before init op";
140142
listen_and_serv_op =
141143
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
142144
*initialized = true;
@@ -149,7 +151,9 @@ TEST(SendRecvOp, CPUDense) {
149151
std::thread server_thread(StartServerNet, false, &initialized);
150152
while (!initialized) {
151153
}
152-
sleep(5); // wait server to start
154+
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
155+
->WaitServerReady();
156+
153157
// local net
154158
f::Scope scope;
155159
p::CPUPlace place;
@@ -185,6 +189,7 @@ TEST(SendRecvOp, CPUDense) {
185189
listen_and_serv_op->Stop();
186190
server_thread.join();
187191
listen_and_serv_op.reset(nullptr);
192+
paddle::operators::ListenAndServOp::ResetPort();
188193
}
189194

190195
TEST(SendRecvOp, CPUSparse) {
@@ -193,18 +198,19 @@ TEST(SendRecvOp, CPUSparse) {
193198
std::thread server_thread(StartServerNet, true, &initialized);
194199
while (!initialized) {
195200
}
196-
sleep(5); // wait server to start
201+
auto *listen_and_serv_op_ptr =
202+
static_cast<paddle::operators::ListenAndServOp *>(
203+
listen_and_serv_op.get());
204+
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
205+
listen_and_serv_op_ptr->WaitServerReady();
206+
197207
// local net
198208
f::Scope scope;
199209
p::CPUPlace place;
200210
p::CPUDeviceContext ctx(place);
201211
InitSelectedRowsInScope(place, &scope);
202212
scope.Var("RPC_CLIENT_VAR");
203213
f::AttributeMap attrs;
204-
auto *listen_and_serv_op_ptr =
205-
static_cast<paddle::operators::ListenAndServOp *>(
206-
listen_and_serv_op.get());
207-
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
208214
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
209215
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
210216
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
@@ -236,4 +242,5 @@ TEST(SendRecvOp, CPUSparse) {
236242
listen_and_serv_op->Stop();
237243
server_thread.join();
238244
listen_and_serv_op.reset();
245+
paddle::operators::ListenAndServOp::ResetPort();
239246
}

python/paddle/fluid/tests/unittests/test_dist_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_send(self):
3434
p.start()
3535

3636
time.sleep(10)
37-
with open("/tmp/paddle.selected_port", "r") as fn:
37+
with open("/tmp/paddle.%d.selected_port" % p.pid, "r") as fn:
3838
selected_port = int(fn.readlines()[0])
3939
self.init_client(place, selected_port)
4040

0 commit comments

Comments
 (0)