Skip to content

Commit 8f7c773

Browse files
committed
refine listen_and_serv_op
1 parent cec4e6e commit 8f7c773

File tree

4 files changed

+75
-66
lines changed

4 files changed

+75
-66
lines changed

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class AsyncGRPCServer final {
6767
prefetch_ctx_ = prepared;
6868
}
6969

70-
int GetSelectedPort() { return selected_port_; }
70+
int GetSelectedPort() const { return selected_port_; }
7171

7272
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
7373

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,6 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
2727
VLOG(4) << "RunServer thread end";
2828
}
2929

30-
static void CreateTensorFromMessageType(framework::Variable *var,
31-
sendrecv::VarType var_type) {
32-
if (var_type == sendrecv::VarType::LOD_TENSOR) {
33-
var->GetMutable<framework::LoDTensor>();
34-
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
35-
var->GetMutable<framework::SelectedRows>();
36-
} else {
37-
PADDLE_THROW(
38-
"VariableMessage type %d is not in "
39-
"[LoDTensor, SelectedRows]",
40-
var_type);
41-
}
42-
}
43-
4430
static void ParallelExecuteBlocks(
4531
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
4632
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
@@ -77,59 +63,37 @@ void ListenAndServOp::Stop() {
7763
server_thread_->join();
7864
}
7965

80-
void ListenAndServOp::RunImpl(const framework::Scope &scope,
81-
const platform::Place &dev_place) const {
82-
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
83-
auto &dev_ctx = *pool.Get(dev_place);
84-
framework::Scope &recv_scope = scope.NewScope();
85-
86-
if (!rpc_service_) {
87-
std::string endpoint = Attr<std::string>("endpoint");
88-
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
89-
}
66+
void ListenAndServOp::PreparePrefetchCtx(
67+
framework::Executor *executor, framework::BlockDesc *prefetch_block,
68+
framework::ProgramDesc *program) const {
69+
// TODO(qiao) set proper fields for table lookup and update
70+
rpc_service_->SetExecutor(executor);
71+
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
72+
auto prefetch_prepared = executor->Prepare(*program, prefetch_block->ID());
73+
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
74+
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
75+
prefetch_prepared.release();
76+
}
9077

91-
auto ins = Inputs("X");
78+
void ListenAndServOp::RunSyncUpdate(
79+
framework::Executor *executor, framework::ProgramDesc *program,
80+
framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const {
9281
auto fan_in = Attr<int>("Fanin");
93-
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
94-
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
95-
auto *program = optimize_block->Program();
82+
9683
size_t num_blocks = program->Size();
9784
PADDLE_ENFORCE_GE(num_blocks, 2,
9885
"server program should have at least 2 blocks");
9986

100-
framework::Executor executor(dev_place);
10187
std::vector<int> block_list;
10288
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
103-
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
104-
block_list.push_back(blkid);
105-
}
89+
block_list.push_back(blkid);
10690
}
107-
auto optimize_prepared = executor.Prepare(*program, block_list);
91+
auto optimize_prepared = executor->Prepare(*program, block_list);
10892
// Insert placeholder for block0 which holds current op itself.
10993
optimize_prepared.insert(
11094
optimize_prepared.begin(),
11195
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
11296

113-
rpc_service_->SetScope(&recv_scope);
114-
rpc_service_->SetDevCtx(&dev_ctx);
115-
// TODO(qiao) set proper fields for table lookup and update
116-
rpc_service_->SetExecutor(&executor);
117-
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
118-
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
119-
rpc_service_->SetPrefetchBlkdId(prefetch_block->ID());
120-
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
121-
prefetch_prepared.release();
122-
rpc_service_->SetProgram(program);
123-
// start the server listening after all member initialized.
124-
server_thread_.reset(new std::thread(RunServer, rpc_service_));
125-
VLOG(3) << "wait server thread to become ready...";
126-
sleep(5);
127-
// Write to a file of server selected port for python use.
128-
std::ofstream port_file;
129-
port_file.open("/tmp/paddle.selected_port");
130-
port_file << rpc_service_->GetSelectedPort();
131-
port_file.close();
132-
13397
bool exit_flag = false;
13498
// Record received sparse variables, so that
13599
// we could reset those after execute optimize program
@@ -170,7 +134,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
170134
break;
171135
}
172136

173-
// NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads
137+
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
174138
// and this will still work.
175139

176140
// The optimize blocks which have the same parent ID would run parallel
@@ -182,16 +146,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
182146
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
183147
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
184148
if (program->Block(blkid).Parent() != last_parent_blkid) {
185-
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
186-
program, &recv_scope);
149+
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
150+
program, recv_scope);
187151
parallel_blkids.clear();
188152
last_parent_blkid = program->Block(blkid).Parent();
189153
}
190154
parallel_blkids.push_back(blkid);
191155
}
192156
}
193-
ParallelExecuteBlocks(parallel_blkids, &executor, optimize_prepared,
194-
program, &recv_scope);
157+
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
158+
recv_scope);
195159
VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
196160

197161
// Reset the received sparse variables, the sum operator would not
@@ -209,6 +173,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
209173
} // while(true)
210174
}
211175

176+
static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
177+
std::ofstream port_file;
178+
port_file.open("/tmp/paddle.selected_port");
179+
port_file << rpc_service->GetSelectedPort();
180+
port_file.close();
181+
}
182+
183+
void ListenAndServOp::RunImpl(const framework::Scope &scope,
184+
const platform::Place &dev_place) const {
185+
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
186+
auto &dev_ctx = *pool.Get(dev_place);
187+
framework::Scope &recv_scope = scope.NewScope();
188+
189+
PADDLE_ENFORCE(!rpc_service_);
190+
std::string endpoint = Attr<std::string>("endpoint");
191+
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
192+
193+
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
194+
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
195+
auto *program = optimize_block->Program();
196+
framework::Executor executor(dev_place);
197+
198+
// prepare rpc_service
199+
rpc_service_->SetScope(&recv_scope);
200+
rpc_service_->SetDevCtx(&dev_ctx);
201+
rpc_service_->SetProgram(program);
202+
PreparePrefetchCtx(&executor, prefetch_block, program);
203+
// start the server listening after all member initialized.
204+
server_thread_.reset(new std::thread(RunServer, rpc_service_));
205+
VLOG(3) << "wait server thread to become ready...";
206+
sleep(5);
207+
// Write to a file of server selected port for python use.
208+
SavePort(rpc_service_);
209+
RunSyncUpdate(&executor, program, &recv_scope, prefetch_block);
210+
}
211+
212212
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
213213
public:
214214
ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,26 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
3434

3535
class ListenAndServOp : public framework::OperatorBase {
3636
public:
37-
ListenAndServOp(const std::string &type,
38-
const framework::VariableNameMap &inputs,
39-
const framework::VariableNameMap &outputs,
40-
const framework::AttributeMap &attrs);
37+
ListenAndServOp(const std::string& type,
38+
const framework::VariableNameMap& inputs,
39+
const framework::VariableNameMap& outputs,
40+
const framework::AttributeMap& attrs);
4141

4242
int GetSelectedPort() const;
4343

44+
void PreparePrefetchCtx(framework::Executor* executor,
45+
framework::BlockDesc* prefetch_block,
46+
framework::ProgramDesc* program) const;
47+
48+
void RunSyncUpdate(framework::Executor* executor,
49+
framework::ProgramDesc* program,
50+
framework::Scope* recv_scope,
51+
framework::BlockDesc* prefetch_block) const;
52+
4453
void Stop() override;
4554

46-
void RunImpl(const framework::Scope &scope,
47-
const platform::Place &dev_place) const override;
55+
void RunImpl(const framework::Scope& scope,
56+
const platform::Place& dev_place) const override;
4857

4958
protected:
5059
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;

paddle/fluid/operators/send_recv_op_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ void StartServerNet(bool is_sparse) {
127127
const auto &root_block = program.Block(0);
128128
auto *optimize_block = program.AppendBlock(root_block);
129129
auto *prefetch_block = program.AppendBlock(root_block);
130-
// X for server side tensors, RX for received tensers, must be of same shape.
130+
// X for server side tensors, RX for received tensors, must be of same shape.
131131
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
132132

133133
f::AttributeMap attrs;

0 commit comments

Comments
 (0)