Skip to content

Commit fb8c1cf

Browse files
authored
Merge pull request #9377 from typhoonzero/prepare_pserver_executor
prepare pserver executor
2 parents 66e0aed + 1f6e044 commit fb8c1cf

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,21 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
279279
return std::unique_ptr<ExecutorPrepareContext>(ctx);
280280
}
281281

282+
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
283+
const ProgramDesc& program, const std::vector<int>& block_ids) {
284+
std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
285+
for (auto& bid : block_ids) {
286+
auto* ctx = new ExecutorPrepareContext(program, bid);
287+
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
288+
auto& block = program.Block(bid);
289+
for (auto& op_desc : block.AllOps()) {
290+
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
291+
}
292+
result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
293+
}
294+
return result;
295+
}
296+
282297
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
283298
bool create_local_scope, bool create_vars) {
284299
auto& block = ctx->prog_.Block(ctx->block_id_);

paddle/fluid/framework/executor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class Executor {
6161
static std::unique_ptr<ExecutorPrepareContext> Prepare(
6262
const ProgramDesc& program, int block_id);
6363

64+
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
65+
const ProgramDesc& program, const std::vector<int>& block_ids);
66+
6467
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
6568
bool create_local_scope = true,
6669
bool create_vars = true);

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,23 @@ static void CreateTensorFromMessageType(framework::Variable *var,
4545
}
4646
}
4747

48-
static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
49-
framework::Executor *executor,
50-
framework::ProgramDesc *program,
51-
framework::Scope *scope) {
48+
static void ParallelExecuteBlocks(
49+
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
50+
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
51+
&prepared,
52+
framework::ProgramDesc *program, framework::Scope *scope) {
5253
std::vector<std::future<void>> fs;
5354
for (size_t idx : parallel_blkids) {
54-
fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
55-
int run_block = idx; // thread local
56-
try {
57-
executor->Run(*program, scope, run_block, false, false);
58-
} catch (std::exception &e) {
59-
LOG(ERROR) << "run sub program error " << e.what();
60-
}
61-
}));
55+
fs.push_back(
56+
framework::Async([&executor, &prepared, &program, &scope, idx]() {
57+
int run_block = idx; // thread local
58+
try {
59+
executor->RunPreparedContext(prepared[run_block].get(), scope,
60+
false, false);
61+
} catch (std::exception &e) {
62+
LOG(ERROR) << "run sub program error " << e.what();
63+
}
64+
}));
6265
}
6366
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
6467
}
@@ -101,6 +104,13 @@ class ListenAndServOp : public framework::OperatorBase {
101104
"server program should have at least 2 blocks");
102105

103106
framework::Executor executor(dev_place);
107+
std::vector<int> block_list;
108+
for (size_t blkid = 1; blkid < num_blocks; ++blkid)
109+
block_list.push_back(blkid);
110+
auto prepared = executor.Prepare(*program, block_list);
111+
prepared.insert(
112+
prepared.begin(),
113+
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
104114

105115
// TODO(qiao) set proper fields for table lookup and update
106116
rpc_service_->SetExecutor(&executor);
@@ -160,14 +170,15 @@ class ListenAndServOp : public framework::OperatorBase {
160170
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
161171
if (program->Block(blkid).Parent() != last_parent_blkid) {
162172
for (size_t idx : parallel_blkids) VLOG(3) << idx;
163-
ParallelExecuteBlocks(parallel_blkids, &executor, program,
173+
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
164174
&recv_scope);
165175
parallel_blkids.clear();
166176
last_parent_blkid = program->Block(blkid).Parent();
167177
}
168178
parallel_blkids.push_back(blkid);
169179
}
170-
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
180+
ParallelExecuteBlocks(parallel_blkids, &executor, prepared, program,
181+
&recv_scope);
171182

172183
VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
173184
<< "(ms)";
@@ -181,7 +192,8 @@ class ListenAndServOp : public framework::OperatorBase {
181192
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
182193
}
183194
rpc_service_->SetCond(1);
184-
// FIXME(typhoonzero): use another condition to sync wait clients get.
195+
// NOTE: does not consider barrier request retry in here, we may use
196+
// global barrier id to resolve this.
185197
rpc_service_->WaitClientGet(fan_in);
186198
sparse_vars.clear();
187199
} // while(true)

0 commit comments

Comments
 (0)