Commit 7a993ee

Merge pull request #10080 from jacquesqiao/refine-listen-and-serve-op
Refine listen and serve op
2 parents: f2e400d + 0f5a9cc

File tree

4 files changed: +68 −70 lines changed

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 1 addition & 4 deletions
@@ -59,15 +59,13 @@ class AsyncGRPCServer final {

  void SetProgram(framework::ProgramDesc *program) { program_ = program; }

-  void SetPrefetchBlkdId(int blkid) { prefetch_blk_id_ = blkid; }
-
  void SetExecutor(framework::Executor *executor) { executor_ = executor; }

  void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
    prefetch_ctx_ = prepared;
  }

-  int GetSelectedPort() { return selected_port_; }
+  int GetSelectedPort() const { return selected_port_; }

  const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }

@@ -114,7 +112,6 @@ class AsyncGRPCServer final {
  std::unique_ptr<std::thread> t_get_;
  std::unique_ptr<std::thread> t_prefetch_;

-  int prefetch_blk_id_;
  framework::ExecutorPrepareContext *prefetch_ctx_;
  framework::ProgramDesc *program_;
  framework::Executor *executor_;
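
For context, the substantive interface change in this header is that GetSelectedPort() is now const-qualified (the unused prefetch_blk_id_ plumbing is also removed). Below is a minimal standalone sketch, not Paddle code, of why const-qualifying the accessor matters; the Server class and port value are made up for illustration.

#include <iostream>

// Stand-in for AsyncGRPCServer; only the accessor pattern is relevant here.
class Server {
 public:
  explicit Server(int port) : selected_port_(port) {}
  // const-qualified accessor, as in the updated header.
  int GetSelectedPort() const { return selected_port_; }

 private:
  int selected_port_;
};

// Takes the server by const reference; this compiles only because
// GetSelectedPort() is declared const.
void ReportPort(const Server &server) {
  std::cout << "selected port: " << server.GetSelectedPort() << std::endl;
}

int main() {
  Server server(12345);  // hypothetical port value
  ReportPort(server);
  return 0;
}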

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 55 additions & 59 deletions
@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  VLOG(4) << "RunServer thread end";
}

-static void CreateTensorFromMessageType(framework::Variable *var,
-                                        sendrecv::VarType var_type) {
-  if (var_type == sendrecv::VarType::LOD_TENSOR) {
-    var->GetMutable<framework::LoDTensor>();
-  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
-    var->GetMutable<framework::SelectedRows>();
-  } else {
-    PADDLE_THROW(
-        "VariableMessage type %d is not in "
-        "[LoDTensor, SelectedRows]",
-        var_type);
-  }
-}
-
static void ParallelExecuteBlocks(
    const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
    const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -62,6 +48,13 @@ static void ParallelExecuteBlocks(
  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}

+static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
+  std::ofstream port_file;
+  port_file.open("/tmp/paddle.selected_port");
+  port_file << rpc_service->GetSelectedPort();
+  port_file.close();
+}
+
ListenAndServOp::ListenAndServOp(const std::string &type,
                                 const framework::VariableNameMap &inputs,
                                 const framework::VariableNameMap &outputs,
@@ -77,59 +70,26 @@ void ListenAndServOp::Stop() {
  server_thread_->join();
}

-void ListenAndServOp::RunImpl(const framework::Scope &scope,
-                              const platform::Place &dev_place) const {
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto &dev_ctx = *pool.Get(dev_place);
-  framework::Scope &recv_scope = scope.NewScope();
-
-  if (!rpc_service_) {
-    std::string endpoint = Attr<std::string>("endpoint");
-    rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
-  }
-
-  auto ins = Inputs("X");
+void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
+                                  framework::ProgramDesc *program,
+                                  framework::Scope *recv_scope,
+                                  framework::BlockDesc *prefetch_block) const {
  auto fan_in = Attr<int>("Fanin");
-  auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-  auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
-  auto *program = optimize_block->Program();
+
  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

-  framework::Executor executor(dev_place);
  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
-    if (blkid != static_cast<size_t>(prefetch_block->ID())) {
-      block_list.push_back(blkid);
-    }
+    block_list.push_back(blkid);
  }
-  auto optimize_prepared = executor.Prepare(*program, block_list);
+  auto optimize_prepared = executor->Prepare(*program, block_list);
  // Insert placeholder for block0 which holds current op itself.
  optimize_prepared.insert(
      optimize_prepared.begin(),
      std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));

-  rpc_service_->SetScope(&recv_scope);
-  rpc_service_->SetDevCtx(&dev_ctx);
-  // TODO(qiao) set proper fields for table lookup and update
-  rpc_service_->SetExecutor(&executor);
-  VLOG(3) << "prefetch block id is " << prefetch_block->ID();
-  auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
-  rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
-  rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
-  prefetch_prepared.release();
-  rpc_service_->SetProgram(program);
-  // start the server listening after all member initialized.
-  server_thread_.reset(new std::thread(RunServer, rpc_service_));
-  VLOG(3) << "wait server thread to become ready...";
-  sleep(5);
-  // Write to a file of server selected port for python use.
-  std::ofstream port_file;
-  port_file.open("/tmp/paddle.selected_port");
-  port_file << rpc_service_->GetSelectedPort();
-  port_file.close();
-
  bool exit_flag = false;
  // Record received sparse variables, so that
  // we could reset those after execute optimize program
@@ -170,7 +130,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
      break;
    }

-    // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
+    // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
    // and this will still work.

    // The optimize blocks which have the same parent ID would run parallel
@@ -182,16 +142,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
      if (blkid != static_cast<size_t>(prefetch_block->ID())) {
        if (program->Block(blkid).Parent() != last_parent_blkid) {
-          ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
-                                program, &recv_scope);
+          ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
+                                program, recv_scope);
          parallel_blkids.clear();
          last_parent_blkid = program->Block(blkid).Parent();
        }
        parallel_blkids.push_back(blkid);
      }
    }
-    ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
-                          program, &recv_scope);
+    ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
+                          recv_scope);
    VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";

    // Reset the received sparse variables, the sum operator would not
@@ -209,6 +169,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
  }  // while(true)
}

+void ListenAndServOp::RunImpl(const framework::Scope &scope,
+                              const platform::Place &dev_place) const {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto &dev_ctx = *pool.Get(dev_place);
+  framework::Scope &recv_scope = scope.NewScope();
+
+  PADDLE_ENFORCE(!rpc_service_);
+  std::string endpoint = Attr<std::string>("endpoint");
+  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
+
+  auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+  auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
+  auto *program = optimize_block->Program();
+  framework::Executor executor(dev_place);
+
+  // prepare rpc_service
+  rpc_service_->SetScope(&recv_scope);
+  rpc_service_->SetDevCtx(&dev_ctx);
+  rpc_service_->SetProgram(program);
+  rpc_service_->SetExecutor(&executor);
+
+  // prepare for prefetch
+  VLOG(3) << "prefetch block id is " << prefetch_block->ID();
+  auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
+  rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
+  prefetch_prepared.release();
+
+  // start the server listening after all member initialized.
+  server_thread_.reset(new std::thread(RunServer, rpc_service_));
+  VLOG(3) << "wait server thread to become ready...";
+  sleep(5);
+  // Write to a file of server selected port for python use.
+  SavePort(rpc_service_);
+  RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
+}
+
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
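
To summarize the refactor in this file: the old monolithic RunImpl is split into a setup phase (RunImpl, which wires up the AsyncGRPCServer) and a blocking RunSyncLoop, and the port-file write is pulled out into the SavePort helper. Below is a self-contained sketch of what that helper does, assuming the Paddle server is replaced by a plain int port and /tmp is writable; the reader side is hypothetical and only illustrates how another process (e.g. the Python launcher) would discover the port.

#include <fstream>
#include <iostream>

// Mirrors the SavePort helper added in this PR, but takes the port directly
// instead of a std::shared_ptr<detail::AsyncGRPCServer>.
static void SavePort(int selected_port) {
  std::ofstream port_file;
  port_file.open("/tmp/paddle.selected_port");
  port_file << selected_port;
  port_file.close();
}

int main() {
  SavePort(54321);  // hypothetical port chosen by the server

  // Another process can then read the selected port back from the file.
  std::ifstream in("/tmp/paddle.selected_port");
  int port = 0;
  in >> port;
  std::cout << "discovered port: " << port << std::endl;
  return 0;
}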

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 11 additions & 6 deletions
@@ -34,17 +34,22 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);

class ListenAndServOp : public framework::OperatorBase {
 public:
-  ListenAndServOp(const std::string &type,
-                  const framework::VariableNameMap &inputs,
-                  const framework::VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs);
+  ListenAndServOp(const std::string& type,
+                  const framework::VariableNameMap& inputs,
+                  const framework::VariableNameMap& outputs,
+                  const framework::AttributeMap& attrs);

  int GetSelectedPort() const;

+  void RunSyncLoop(framework::Executor* executor,
+                   framework::ProgramDesc* program,
+                   framework::Scope* recv_scope,
+                   framework::BlockDesc* prefetch_block) const;
+
  void Stop() override;

-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &dev_place) const override;
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override;

 protected:
  mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;

paddle/fluid/operators/send_recv_op_test.cc

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ void StartServerNet(bool is_sparse) {
  const auto &root_block = program.Block(0);
  auto *optimize_block = program.AppendBlock(root_block);
  auto *prefetch_block = program.AppendBlock(root_block);
-  // X for server side tensors, RX for received tensers, must be of same shape.
+  // X for server side tensors, RX for received tensors, must be of same shape.
  AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);

  f::AttributeMap attrs;
