
Commit d139f2c

Merge pull request #9595 from typhoonzero/fix_test_sendrecv_portbind
Fix sendrecv port bind
2 parents 5641859 + b03fa88 commit d139f2c

File tree

6 files changed (+209 lines, -144 lines)


paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -193,6 +193,7 @@ if(WITH_DISTRIBUTE)
     set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
     set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+    set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
     cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
 else()
     set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 4 additions & 2 deletions
@@ -186,7 +186,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
 
 void AsyncGRPCServer::RunSyncUpdate() {
   ::grpc::ServerBuilder builder;
-  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials());
+  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
+                           &selected_port_);
   builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
   builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
   builder.RegisterService(&service_);
@@ -196,7 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() {
   cq_prefetch_ = builder.AddCompletionQueue();
 
   server_ = builder.BuildAndStart();
-  LOG(INFO) << "Server listening on " << address_ << std::endl;
+  LOG(INFO) << "Server listening on " << address_
+            << " selected port: " << selected_port_;
 
   std::function<void()> send_register =
       std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,8 @@ class AsyncGRPCServer final {
 
   void SetExecutor(framework::Executor *executor) { executor_ = executor; }
 
+  int GetSelectedPort() { return selected_port_; }
+
   const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
 
   void Push(const std::string &msg_name) {
@@ -111,6 +113,7 @@ class AsyncGRPCServer final {
   int prefetch_blk_id_;
   framework::ProgramDesc *program_;
   framework::Executor *executor_;
+  int selected_port_;
 };
 
 };  // namespace detail
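
Note on the gRPC API used above: ::grpc::ServerBuilder::AddListeningPort accepts an optional int* out-parameter that receives the port actually bound once BuildAndStart() returns, which is what lets selected_port_ report the real port when the endpoint ends in ":0". A minimal standalone sketch of that pattern (the address and program below are illustrative, not part of this change):

#include <grpc++/grpc++.h>
#include <iostream>
#include <memory>

int main() {
  ::grpc::ServerBuilder builder;
  int selected_port = 0;
  // Bind ":0" so the OS picks a free port; gRPC writes the chosen port
  // into selected_port when BuildAndStart() completes.
  builder.AddListeningPort("127.0.0.1:0", ::grpc::InsecureServerCredentials(),
                           &selected_port);
  std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
  std::cout << "bound to port " << selected_port << std::endl;
  server->Shutdown();
  return 0;
}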

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 126 additions & 137 deletions
@@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <stdint.h>
 #include <ostream>
+#include <thread>
 
-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"
 
 namespace paddle {
 namespace operators {
 
-constexpr char kOptimizeBlock[] = "OptimizeBlock";
-
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
@@ -66,143 +60,138 @@ static void ParallelExecuteBlocks(
   for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
 }
 
-class ListenAndServOp : public framework::OperatorBase {
- public:
-  ListenAndServOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {
-    if (!rpc_service_) {
-      std::string endpoint = Attr<std::string>("endpoint");
-      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
-      server_thread_.reset(new std::thread(RunServer, rpc_service_));
-    }
-  }
+ListenAndServOp::ListenAndServOp(const std::string &type,
+                                 const framework::VariableNameMap &inputs,
+                                 const framework::VariableNameMap &outputs,
+                                 const framework::AttributeMap &attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
 
-  void Stop() override {
-    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
-    server_thread_->join();
+int ListenAndServOp::GetSelectedPort() {
+  return rpc_service_->GetSelectedPort();
+}
+
+void ListenAndServOp::Stop() {
+  rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
+  server_thread_->join();
+}
+
+void ListenAndServOp::RunImpl(const framework::Scope &scope,
+                              const platform::Place &dev_place) const {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(dev_place);
+  framework::Scope &recv_scope = scope.NewScope();
+
+  if (!rpc_service_) {
+    std::string endpoint = Attr<std::string>("endpoint");
+    rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
   }
 
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(dev_place);
-    framework::Scope &recv_scope = scope.NewScope();
-
-    // FIXME(Yancey1989): initialize rpc server with lazy mode.
-    rpc_service_->SetScope(&recv_scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    auto ins = Inputs("X");
-    auto fan_in = Attr<int>("Fanin");
-
-    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-    auto *program = block->Program();
-    size_t num_blocks = program->Size();
-    PADDLE_ENFORCE_GE(num_blocks, 2,
-                      "server program should have at least 2 blocks");
-
-    framework::Executor executor(dev_place);
-    std::vector<int> block_list;
-    for (size_t blkid = 1; blkid < num_blocks; ++blkid)
-      block_list.push_back(blkid);
-    auto prepared = executor.Prepare(*program, block_list);
-    prepared.insert(
-        prepared.begin(),
-        std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
-
-    // TODO(qiao) set proper fields for table lookup and update
-    rpc_service_->SetExecutor(&executor);
-    rpc_service_->SetPrefetchBlkdId(0);
-    rpc_service_->SetProgram(program);
-
-    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
-    bool exit_flag = false;
-    // Record received sparse variables, so that
-    // we could reset those after execute optimize program
-    std::vector<framework::Variable *> sparse_vars;
-    while (!exit_flag) {
-      // Get from multiple trainers, we don't care about the order in which
-      // the gradients arrives, just add suffix 0~n and merge the gradient.
-      rpc_service_->SetCond(0);
-      size_t recv_var_cnt = 0;
-      int batch_barrier = 0;
-      while (batch_barrier != fan_in) {
-        const detail::ReceivedMessage v = rpc_service_->Get();
-        auto recv_var_name = v.first;
-        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
-          LOG(INFO) << "received terminate message and exit";
-          exit_flag = true;
-          break;
-        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
-          VLOG(3) << "recv batch barrier message";
-          batch_barrier++;
-          continue;
-        } else {
-          VLOG(3) << "received grad: " << recv_var_name;
-          recv_var_cnt++;
-          auto var = v.second->GetVar();
-          if (var == nullptr) {
-            LOG(ERROR) << "Can not find server side var: " << recv_var_name;
-            PADDLE_THROW("Can not find server side var");
-          }
-          if (var->IsType<framework::SelectedRows>()) {
-            sparse_vars.push_back(var);
-          }
-        }
-      }
-      if (exit_flag) {
-        rpc_service_->SetCond(1);
-        rpc_service_->ShutDown();
+  auto ins = Inputs("X");
+  auto fan_in = Attr<int>("Fanin");
+  auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+  auto *program = block->Program();
+  size_t num_blocks = program->Size();
+  PADDLE_ENFORCE_GE(num_blocks, 2,
+                    "server program should have at least 2 blocks");
+
+  framework::Executor executor(dev_place);
+  std::vector<int> block_list;
+  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
+    block_list.push_back(blkid);
+  }
+  auto prepared = executor.Prepare(*program, block_list);
+  // Insert placeholder for block0 which holds current op itself.
+  prepared.insert(prepared.begin(),
+                  std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
+
+  rpc_service_->SetScope(&recv_scope);
+  rpc_service_->SetDevCtx(&dev_ctx);
+  // TODO(qiao) set proper fields for table lookup and update
+  rpc_service_->SetExecutor(&executor);
+  rpc_service_->SetPrefetchBlkdId(0);
+  rpc_service_->SetProgram(program);
+  // start the server listening after all member initialized.
+  server_thread_.reset(new std::thread(RunServer, rpc_service_));
+  // FIXME(typhoonzero): do we need to wait until the server port is ready?
+  sleep(5);
+
+  // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+  bool exit_flag = false;
+  // Record received sparse variables, so that
+  // we could reset those after execute optimize program
+  std::vector<framework::Variable *> sparse_vars;
+  while (!exit_flag) {
+    // Get from multiple trainers, we don't care about the order in which
+    // the gradients arrives, just add suffix 0~n and merge the gradient.
+    rpc_service_->SetCond(0);
+    size_t recv_var_cnt = 0;
+    int batch_barrier = 0;
+    while (batch_barrier != fan_in) {
+      const detail::ReceivedMessage v = rpc_service_->Get();
+      auto recv_var_name = v.first;
+      if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
+        LOG(INFO) << "received terminate message and exit";
+        exit_flag = true;
         break;
-      }
-
-      // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
-      // and this will still work.
-
-      // The optimize blocks which have the same parent ID would run parallel
-      // TODO(Yancey1989): need to use ParallelExecutor for future
-      int32_t last_parent_blkid = program->Block(1).Parent();
-      std::vector<size_t> parallel_blkids;
-      parallel_blkids.push_back(1);
-      double ts = detail::GetTimestamp();
-      for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
-        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          for (size_t idx : parallel_blkids) VLOG(3) << idx;
-          ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
-                                &recv_scope);
-          parallel_blkids.clear();
-          last_parent_blkid = program->Block(blkid).Parent();
+      } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
+        VLOG(3) << "recv batch barrier message";
+        batch_barrier++;
+        continue;
+      } else {
+        VLOG(3) << "received grad: " << recv_var_name;
+        recv_var_cnt++;
+        auto var = v.second->GetVar();
+        if (var == nullptr) {
+          LOG(ERROR) << "Can not find server side var: " << recv_var_name;
+          PADDLE_THROW("Can not find server side var");
+        }
+        if (var->IsType<framework::SelectedRows>()) {
+          sparse_vars.push_back(var);
         }
-        parallel_blkids.push_back(blkid);
-      }
-      ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
-                            &recv_scope);
-
-      VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
-              << "(ms)";
-
-      // Reset the received sparse variables, the sum operator would not
-      // sum the input sparse variables which rows is empty at the next
-      // mini-batch.
-      // TODO(Yancey1989): move the reset action into an operator, we couldn't
-      // have any hide logic in the operator.
-      for (auto &var : sparse_vars) {
-        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
       }
+    }
+    if (exit_flag) {
       rpc_service_->SetCond(1);
-      // NOTE: does not consider barrier request retry in here, we may use
-      // global barrier id to resolve this.
-      rpc_service_->WaitClientGet(fan_in);
-      sparse_vars.clear();
-    }  // while(true)
-  }
+      rpc_service_->ShutDown();
+      break;
+    }
 
- protected:
-  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
-  std::shared_ptr<std::thread> server_thread_;
-};
+    // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
+    // and this will still work.
+
+    // The optimize blocks which have the same parent ID would run parallel
+    // TODO(Yancey1989): need to use ParallelExecutor for future
+    int32_t last_parent_blkid = program->Block(1).Parent();
+    std::vector<size_t> parallel_blkids;
+    parallel_blkids.push_back(1);
+    double ts = detail::GetTimestamp();
+    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
+      if (program->Block(blkid).Parent() != last_parent_blkid) {
+        ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
+                              &recv_scope);
+        parallel_blkids.clear();
+        last_parent_blkid = program->Block(blkid).Parent();
+      }
+      parallel_blkids.push_back(blkid);
+    }
+    ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
+                          &recv_scope);
+    VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
+
+    // Reset the received sparse variables, the sum operator would not
+    // sum the input sparse variables which rows is empty at the next
+    // mini-batch.
+    // TODO(Yancey1989): move the reset action into an operator, we couldn't
+    // have any hide logic in the operator.
+    for (auto &var : sparse_vars) {
+      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    }
+    rpc_service_->SetCond(1);
+    // FIXME(typhoonzero): use another condition to sync wait clients get.
+    rpc_service_->WaitClientGet(fan_in);
+    sparse_vars.clear();
+  }  // while(true)
+}
 
 class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
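
The sleep(5) and its FIXME above paper over a startup race: RunImpl now launches the server thread itself and then enters the receive loop, so it gives gRPC a fixed head start to finish binding before trainers are expected to connect. One hedged alternative (not part of this commit) is to poll the selected port until it is known, along these lines:

// Hypothetical helper, not in the commit: spin until the gRPC server has
// reported its bound port. Assumes selected_port_ is initialized to 0 before
// RunSyncUpdate() runs, so a non-zero value means the port is bound.
// Requires <thread> and <chrono>.
static void WaitUntilPortBound(detail::AsyncGRPCServer *server) {
  while (server->GetSelectedPort() == 0) {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}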

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <stdint.h>
+#include <ostream>
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/detail/grpc_server.h"
+
+namespace paddle {
+namespace operators {
+
+constexpr char kOptimizeBlock[] = "OptimizeBlock";
+
+void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
+
+class ListenAndServOp : public framework::OperatorBase {
+ public:
+  ListenAndServOp(const std::string &type,
+                  const framework::VariableNameMap &inputs,
+                  const framework::VariableNameMap &outputs,
+                  const framework::AttributeMap &attrs);
+
+  int GetSelectedPort();
+
+  void Stop() override;
+
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override;
+
+ protected:
+  mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
+  mutable std::shared_ptr<std::thread> server_thread_;
+};
+
+}  // namespace operators
+}  // namespace paddle
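
With this header in place, callers can start the server side and discover the bound port through the op itself. A rough usage sketch (op construction and its "endpoint", "Fanin" and OptimizeBlock attributes are elided; the names here are illustrative and not taken from the test in this commit):

// listen_and_serv_op is a ListenAndServOp built elsewhere with its attributes.
std::thread server_thread([&] { listen_and_serv_op.Run(scope, place); });
// ... once the server is up, ask which port the OS actually bound ...
int port = listen_and_serv_op.GetSelectedPort();
// ... point send/recv ops at 127.0.0.1:<port> and run the trainer side ...
listen_and_serv_op.Stop();  // pushes LISTEN_TERMINATE_MESSAGE and joins the gRPC thread
server_thread.join();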
