
Commit 94eea16

fix sendrecv port bind
1 parent 3fd9266 commit 94eea16

File tree

5 files changed: +233 -163 lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 7 additions & 2 deletions

@@ -186,7 +186,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
 
 void AsyncGRPCServer::RunSyncUpdate() {
   ::grpc::ServerBuilder builder;
-  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials());
+  builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
+                           &selected_port_);
   builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
   builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
   builder.RegisterService(&service_);
@@ -196,7 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() {
   cq_prefetch_ = builder.AddCompletionQueue();
 
   server_ = builder.BuildAndStart();
-  LOG(INFO) << "Server listening on " << address_ << std::endl;
+  LOG(INFO) << "Server listening on " << address_
+            << " selected port: " << selected_port_;
 
   std::function<void()> send_register =
       std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
@@ -242,6 +244,9 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
     VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
     return;
   }
+  while (scope_ == nullptr) {
+    sleep(0.01);
+  }
   RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
                                       &var_recv_queue_, dev_ctx_);
   VLOG(4) << "Create RequestSend status:" << send->Status();
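For context on the first hunk: grpc::ServerBuilder::AddListeningPort has a three-argument overload whose final int* receives the port the server actually bound. That is what makes binding to port 0 (let the OS pick a free port) usable, and it is why selected_port_ gets wired through here. A minimal standalone sketch, not Paddle code, assuming a recent gRPC C++ (<grpcpp/grpcpp.h>; trees of this era used <grpc++/grpc++.h>):

    #include <iostream>
    #include <memory>

    #include <grpcpp/grpcpp.h>

    int main() {
      grpc::ServerBuilder builder;
      int selected_port = 0;
      // ":0" asks the OS for any free port; gRPC writes the chosen port into
      // selected_port while BuildAndStart() binds the listener.
      builder.AddListeningPort("127.0.0.1:0", grpc::InsecureServerCredentials(),
                               &selected_port);
      std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
      std::cout << "bound to port " << selected_port << std::endl;  // 0 on failure
      server->Shutdown();
      return 0;
    }

One caveat in the last hunk: POSIX sleep() takes whole seconds as an unsigned integer, so sleep(0.01) truncates to sleep(0) and the wait loop spins; std::this_thread::sleep_for(std::chrono::milliseconds(10)) would give the intended 10 ms pause.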

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 3 additions & 0 deletions

@@ -62,6 +62,8 @@ class AsyncGRPCServer final {
 
   void SetExecutor(framework::Executor *executor) { executor_ = executor; }
 
+  int GetSelectedPort() { return selected_port_; }
+
   const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
 
   void Push(const std::string &msg_name) {
@@ -109,6 +111,7 @@ class AsyncGRPCServer final {
   int prefetch_blk_id_;
   framework::ProgramDesc *program_;
   framework::Executor *executor_;
+  int selected_port_;
 };
 
 }; // namespace detail
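With this accessor in place, callers such as ListenAndServOp can ask the server which port it bound instead of parsing the endpoint string. A hedged usage sketch of the same pattern, where MiniServer is a hypothetical stand-in for AsyncGRPCServer (the real class needs the rest of the Paddle runtime):

    #include <chrono>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <thread>

    #include <grpcpp/grpcpp.h>

    class MiniServer {
     public:
      explicit MiniServer(const std::string& address) : address_(address) {}

      void Run() {
        grpc::ServerBuilder builder;
        builder.AddListeningPort(address_, grpc::InsecureServerCredentials(),
                                 &selected_port_);
        server_ = builder.BuildAndStart();  // fills in selected_port_
        server_->Wait();
      }

      // Like AsyncGRPCServer::GetSelectedPort(), this reads a plain int that
      // another thread writes; production code would publish it under a mutex.
      int GetSelectedPort() const { return selected_port_; }
      void Shutdown() { server_->Shutdown(); }

     private:
      std::string address_;
      int selected_port_ = 0;
      std::unique_ptr<grpc::Server> server_;
    };

    int main() {
      MiniServer server("127.0.0.1:0");
      std::thread t([&server] { server.Run(); });
      // The caller must wait until the server thread has bound the port; this
      // crude poll mirrors the timing problem the commit handles with sleep().
      while (server.GetSelectedPort() == 0) {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
      }
      std::cout << "bound to port " << server.GetSelectedPort() << std::endl;
      server.Shutdown();
      t.join();
      return 0;
    }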

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 116 additions & 156 deletions

@@ -12,185 +12,145 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <stdint.h>
 #include <ostream>
+#include <thread>
 
-#include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"
 
 namespace paddle {
 namespace operators {
 
-constexpr char kOptimizeBlock[] = "OptimizeBlock";
-
 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
   VLOG(4) << "RunServer thread end";
 }
 
-static void CreateTensorFromMessageType(framework::Variable *var,
-                                        sendrecv::VarType var_type) {
-  if (var_type == sendrecv::VarType::LOD_TENSOR) {
-    var->GetMutable<framework::LoDTensor>();
-  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
-    var->GetMutable<framework::SelectedRows>();
-  } else {
-    PADDLE_THROW(
-        "VariableMessage type %d is not in "
-        "[LoDTensor, SelectedRows]",
-        var_type);
-  }
+ListenAndServOp::ListenAndServOp(const std::string &type,
+                                 const framework::VariableNameMap &inputs,
+                                 const framework::VariableNameMap &outputs,
+                                 const framework::AttributeMap &attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
+
+int ListenAndServOp::GetSelectedPort() {
+  return rpc_service_->GetSelectedPort();
 }
 
-static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
-                                  framework::Executor *executor,
-                                  framework::ProgramDesc *program,
-                                  framework::Scope *scope) {
-  std::vector<std::future<void>> fs;
-  for (size_t idx : parallel_blkids) {
-    fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
-      int run_block = idx; // thread local
-      try {
-        executor->Run(*program, scope, run_block, false, false);
-      } catch (std::exception &e) {
-        LOG(ERROR) << "run sub program error " << e.what();
-      }
-    }));
-  }
-  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
+void ListenAndServOp::Stop() {
+  rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
+  server_thread_->join();
 }
 
-class ListenAndServOp : public framework::OperatorBase {
- public:
-  ListenAndServOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {
-    if (!rpc_service_) {
-      std::string endpoint = Attr<std::string>("endpoint");
-      rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
-      server_thread_.reset(new std::thread(RunServer, rpc_service_));
-    }
-  }
+void ListenAndServOp::RunImpl(const framework::Scope &scope,
+                              const platform::Place &dev_place) const {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(dev_place);
+  framework::Scope &recv_scope = scope.NewScope();
+  LOG(INFO) << "created recv scope: " << &recv_scope;
 
-  void Stop() override {
-    rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
-    server_thread_->join();
+  if (!rpc_service_) {
+    std::string endpoint = Attr<std::string>("endpoint");
+    rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
   }
 
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(dev_place);
-    framework::Scope &recv_scope = scope.NewScope();
-
-    // FIXME(Yancey1989): initialize rpc server with lazy mode.
-    rpc_service_->SetScope(&recv_scope);
-    rpc_service_->SetDevCtx(&dev_ctx);
-    auto ins = Inputs("X");
-    auto fan_in = Attr<int>("Fanin");
-
-    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-    auto *program = block->Program();
-    int num_blocks = program->Size();
-    PADDLE_ENFORCE_GE(num_blocks, 2,
-                      "server program should have at least 2 blocks");
-
-    framework::Executor executor(dev_place);
-
-    // TODO(qiao) set proper fields for table lookup and update
-    rpc_service_->SetExecutor(&executor);
-    rpc_service_->SetPrefetchBlkdId(0);
-    rpc_service_->SetProgram(program);
-
-    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
-    bool exit_flag = false;
-    // Record received sparse variables, so that
-    // we could reset those after execute optimize program
-    std::vector<framework::Variable *> sparse_vars;
-    while (!exit_flag) {
-      // Get from multiple trainers, we don't care about the order in which
-      // the gradients arrives, just add suffix 0~n and merge the gradient.
-      rpc_service_->SetCond(0);
-      size_t recv_var_cnt = 0;
-      int batch_barrier = 0;
-      while (batch_barrier != fan_in) {
-        const detail::ReceivedMessage v = rpc_service_->Get();
-        auto recv_var_name = v.first;
-        if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
-          LOG(INFO) << "received terminate message and exit";
-          exit_flag = true;
-          break;
-        } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
-          VLOG(3) << "recv batch barrier message";
-          batch_barrier++;
-          continue;
-        } else {
-          VLOG(3) << "received grad: " << recv_var_name;
-          recv_var_cnt++;
-          auto var = v.second->GetVar();
-          if (var == nullptr) {
-            LOG(ERROR) << "Can not find server side var: " << recv_var_name;
-            PADDLE_THROW("Can not find server side var");
-          }
-          if (var->IsType<framework::SelectedRows>()) {
-            sparse_vars.push_back(var);
-          }
-        }
-      }
-      if (exit_flag) {
-        rpc_service_->SetCond(1);
-        rpc_service_->ShutDown();
+  auto ins = Inputs("X");
+  auto fan_in = Attr<int>("Fanin");
+  auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+  auto *program = block->Program();
+  size_t num_blocks = program->Size();
+  PADDLE_ENFORCE_GE(num_blocks, 2,
+                    "server program should have at least 2 blocks");
+
+  framework::Executor executor(dev_place);
+
+  // FIXME(Yancey1989): initialize rpc server with lazy mode.
+  rpc_service_->SetScope(&recv_scope);
+  rpc_service_->SetDevCtx(&dev_ctx);
+  // TODO(qiao) set proper fields for table lookup and update
+  rpc_service_->SetExecutor(&executor);
+  rpc_service_->SetPrefetchBlkdId(0);
+  rpc_service_->SetProgram(program);
+  // start the server listening after all member initialized.
+  server_thread_.reset(new std::thread(RunServer, rpc_service_));
+  // FIXME(typhoonzero): do we need to wait until the server port is ready?
+  sleep(5);
+
+  // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+  bool exit_flag = false;
+  // Record received sparse variables, so that
+  // we could reset those after execute optimize program
+  std::vector<framework::Variable *> sparse_vars;
+  while (!exit_flag) {
+    // Get from multiple trainers, we don't care about the order in which
+    // the gradients arrives, just add suffix 0~n and merge the gradient.
+    rpc_service_->SetCond(0);
+    size_t recv_var_cnt = 0;
+    int batch_barrier = 0;
+    while (batch_barrier != fan_in) {
+      const detail::ReceivedMessage v = rpc_service_->Get();
+      auto recv_var_name = v.first;
+      if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
+        LOG(INFO) << "received terminate message and exit";
+        exit_flag = true;
         break;
-      }
-
-      // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
-      // and this will still work.
-
-      // The optimize blocks which have the same parent ID would run parallel
-      // TODO(Yancey1989): need to use ParallelExecutor for future
-      size_t last_parent_blkid = program->Block(1).Parent();
-      std::vector<size_t> parallel_blkids;
-      parallel_blkids.push_back(1);
-      double ts = detail::GetTimestamp();
-      for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
-        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          for (size_t idx : parallel_blkids) VLOG(3) << idx;
-          ParallelExecuteBlocks(parallel_blkids, &executor, program,
-                                &recv_scope);
-          parallel_blkids.clear();
-          last_parent_blkid = program->Block(blkid).Parent();
+      } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
+        VLOG(3) << "recv batch barrier message";
+        batch_barrier++;
+        continue;
+      } else {
+        VLOG(3) << "received grad: " << recv_var_name;
+        recv_var_cnt++;
+        auto var = v.second->GetVar();
+        if (var == nullptr) {
+          LOG(ERROR) << "Can not find server side var: " << recv_var_name;
+          PADDLE_THROW("Can not find server side var");
+        }
+        if (var->IsType<framework::SelectedRows>()) {
+          sparse_vars.push_back(var);
         }
-        parallel_blkids.push_back(blkid);
-      }
-      ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
-
-      VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
-              << "(ms)";
-
-      // Reset the received sparse variables, the sum operator would not
-      // sum the input sparse variables which rows is empty at the next
-      // mini-batch.
-      // TODO(Yancey1989): move the reset action into an operator, we couldn't
-      // have any hide logic in the operator.
-      for (auto &var : sparse_vars) {
-        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
       }
+    }
+    if (exit_flag) {
       rpc_service_->SetCond(1);
-      // FIXME(typhoonzero): use another condition to sync wait clients get.
-      rpc_service_->WaitClientGet(fan_in);
-      sparse_vars.clear();
-    } // while(true)
-  }
+      rpc_service_->ShutDown();
+      break;
+    }
 
- protected:
-  std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
-  std::shared_ptr<std::thread> server_thread_;
-};
+    // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
+    // and this will still work.
+
+    // The optimize blocks which have the same parent ID would run parallel
+    // TODO(Yancey1989): need to use ParallelExecutor for future
+    int32_t last_parent_blkid = program->Block(1).Parent();
+    std::vector<size_t> parallel_blkids;
+    parallel_blkids.push_back(1);
+    double ts = detail::GetTimestamp();
+    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
+      if (program->Block(blkid).Parent() != last_parent_blkid) {
+        for (size_t idx : parallel_blkids) VLOG(3) << idx;
+        ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
+        parallel_blkids.clear();
+        last_parent_blkid = program->Block(blkid).Parent();
+      }
+      parallel_blkids.push_back(blkid);
    }
+    ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
+
+    VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
+
+    // Reset the received sparse variables, the sum operator would not
+    // sum the input sparse variables which rows is empty at the next
+    // mini-batch.
+    // TODO(Yancey1989): move the reset action into an operator, we couldn't
+    // have any hide logic in the operator.
+    for (auto &var : sparse_vars) {
+      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    }
+    rpc_service_->SetCond(1);
+    // FIXME(typhoonzero): use another condition to sync wait clients get.
+    rpc_service_->WaitClientGet(fan_in);
+    sparse_vars.clear();
+  } // while(true)
+}
 
 class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
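The sleep(5) in RunImpl papers over a startup race: the op must not start consuming gradients before the RunServer thread has bound the port. One way to answer that FIXME, sketched here as a suggestion rather than the commit's code (StartupGate and its method names are hypothetical), is a condition variable the server thread signals once the port is known:

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>

    class StartupGate {
     public:
      // Called by the server thread after BuildAndStart() publishes the port.
      void NotifyReady(int port) {
        {
          std::lock_guard<std::mutex> lock(mu_);
          port_ = port;
        }
        cv_.notify_all();
      }

      // Called by the launching thread; blocks until the port is published.
      int WaitUntilReady() {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return port_ != 0; });
        return port_;
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      int port_ = 0;
    };

    int main() {
      StartupGate gate;
      std::thread server([&gate] {
        // ... build and start the gRPC server here, then publish the port:
        gate.NotifyReady(12345);  // placeholder value for the sketch
      });
      std::cout << "server ready on port " << gate.WaitUntilReady() << std::endl;
      server.join();
      return 0;
    }

Unlike the fixed five-second delay, the launcher wakes as soon as the port is published, and never proceeds early on a slow machine.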
