Skip to content

Commit 431491a

Browse files
authored
Merge pull request #11366 from jacquesqiao/refine-prefetch
Refine prefetch
2 parents 34865f2 + 2b9ff39 commit 431491a

File tree

9 files changed: 183 additions and 117 deletions

paddle/fluid/framework/details/ssa_graph_checker.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
namespace paddle {
2020
namespace framework {
2121
namespace details {
22-
class SSAGraph;
22+
struct SSAGraph;
2323

2424
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
2525
public:

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,18 @@ class RequestPrefetch final : public RequestBase {
162162

163163
void Process() override {
164164
// prefetch process...
165-
std::string varname = request_->OutVarname();
166-
VLOG(3) << "RequestPrefetch " << varname;
165+
std::string in_var_name = request_->Varname();
166+
std::string out_var_name = request_->OutVarname();
167+
VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
168+
<< " out_var_name: " << out_var_name;
167169

168170
auto scope = request_->GetMutableLocalScope();
169-
auto invar = scope->FindVar(varname);
170-
framework::Variable* outvar = nullptr;
171+
auto invar = scope->FindVar(in_var_name);
172+
framework::Variable* outvar = scope->FindVar(out_var_name);
171173

172-
request_handler_->Handle(varname, scope, invar, &outvar);
174+
request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
173175

174-
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
176+
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
175177
&reply_);
176178
Finish(reply_, &responder_);
177179
}
@@ -287,7 +289,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
287289
} else if (rpc_name == kRequestPrefetch) {
288290
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
289291
} else {
290-
PADDLE_ENFORCE(false, "not surpported rpc");
292+
PADDLE_ENFORCE(false, "not supported rpc");
291293
}
292294

293295
reqs[req_id] = b;

paddle/fluid/operators/detail/request_handler.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,12 @@ class RequestHandler {
6161
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
6262
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
6363
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
64+
65+
// Used for dist lookup table prefetch
6466
void SetPrefetchPreparedCtx(
65-
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
66-
prefetch_ctx_.reset(prepared.release());
67+
std::unordered_map<
68+
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
69+
prefetch_var_name_to_prepared_ctx_ = g;
6770
}
6871

6972
// Used for async.
@@ -79,9 +82,6 @@ class RequestHandler {
7982
bool sync_mode() { return sync_mode_; }
8083
framework::Scope* scope() { return scope_; }
8184
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
82-
framework::ExecutorPrepareContext* prefetch_ctx() {
83-
return prefetch_ctx_.get();
84-
}
8585
framework::ProgramDesc* program() { return program_; }
8686
framework::Executor* executor() { return executor_; }
8787

@@ -100,8 +100,8 @@ class RequestHandler {
100100
// *request_handler_->dev_ctx(), &reply_);
101101
// }
102102
virtual bool Handle(const std::string& varname, framework::Scope* scope,
103-
framework::Variable* var,
104-
framework::Variable** outvar) = 0;
103+
framework::Variable* var, framework::Variable** outvar,
104+
const std::string& out_var_name = "") = 0;
105105

106106
protected:
107107
const bool sync_mode_;
@@ -110,12 +110,17 @@ class RequestHandler {
110110
framework::Executor* executor_;
111111
framework::Scope* scope_;
112112
framework::ProgramDesc* program_;
113-
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
113+
114+
// used for distribute lookup table prefetch
115+
std::unordered_map<std::string,
116+
std::shared_ptr<framework::ExecutorPrepareContext>>*
117+
prefetch_var_name_to_prepared_ctx_;
114118

115119
// Used for async.
116120
std::unordered_map<std::string,
117121
std::shared_ptr<framework::ExecutorPrepareContext>>*
118122
grad_to_prepared_ctx_;
123+
119124
RPCServer* rpc_server_;
120125
};
121126

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ namespace detail {
3030
bool RequestSendHandler::Handle(const std::string& varname,
3131
framework::Scope* scope,
3232
framework::Variable* invar,
33-
framework::Variable** outvar) {
33+
framework::Variable** outvar,
34+
const std::string& out_var_name) {
3435
VLOG(4) << "RequestSendHandler:" << varname;
3536

3637
// Async
@@ -82,7 +83,8 @@ void RequestSendHandler::ResetSparseVarRecorder() {
8283
bool RequestGetHandler::Handle(const std::string& varname,
8384
framework::Scope* scope,
8485
framework::Variable* invar,
85-
framework::Variable** outvar) {
86+
framework::Variable** outvar,
87+
const std::string& out_var_name) {
8688
VLOG(4) << "RequestGetHandler:" << varname;
8789

8890
if (varname != FETCH_BARRIER_MESSAGE) {
@@ -105,13 +107,14 @@ bool RequestGetHandler::Handle(const std::string& varname,
105107
bool RequestPrefetchHandler::Handle(const std::string& varname,
106108
framework::Scope* scope,
107109
framework::Variable* invar,
108-
framework::Variable** outvar) {
110+
framework::Variable** outvar,
111+
const std::string& out_var_name) {
109112
VLOG(4) << "RequestPrefetchHandler " << varname;
110113

111-
auto var_desc = program_->Block(0).FindVar(varname);
112-
*outvar = scope->FindVar(varname);
114+
auto var_desc = program_->Block(0).FindVar(out_var_name);
113115
InitializeVariable(*outvar, var_desc->GetType());
114-
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
116+
executor_->RunPreparedContext(
117+
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
115118

116119
return true;
117120
}

paddle/fluid/operators/detail/request_handler_impl.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ class RequestSendHandler final : public RequestHandler {
3939
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
4040
virtual ~RequestSendHandler() {}
4141
bool Handle(const std::string& varname, framework::Scope* scope,
42-
framework::Variable* var, framework::Variable** outvar) override;
42+
framework::Variable* var, framework::Variable** outvar,
43+
const std::string& out_var_name = "") override;
4344
void ResetSparseVarRecorder();
4445

4546
private:
@@ -52,15 +53,17 @@ class RequestGetHandler final : public RequestHandler {
5253
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
5354
virtual ~RequestGetHandler() {}
5455
bool Handle(const std::string& varname, framework::Scope* scope,
55-
framework::Variable* var, framework::Variable** outvar) override;
56+
framework::Variable* var, framework::Variable** outvar,
57+
const std::string& out_var_name = "") override;
5658
};
5759

5860
class RequestPrefetchHandler final : public RequestHandler {
5961
public:
6062
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
6163
virtual ~RequestPrefetchHandler() {}
6264
bool Handle(const std::string& varname, framework::Scope* scope,
63-
framework::Variable* var, framework::Variable** outvar) override;
65+
framework::Variable* var, framework::Variable** outvar,
66+
const std::string& out_var_name = "") override;
6467
};
6568

6669
} // namespace detail

paddle/fluid/operators/detail/rpc_server_test.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,17 @@ void StartServer() {
9898
framework::Executor exe(place);
9999
platform::CPUDeviceContext ctx(place);
100100
auto* block = AppendPrefetchBlcok(&program);
101-
auto prepared = exe.Prepare(program, block->ID());
101+
std::string in_var_name("ids");
102+
std::vector<int> prefetch_block_ids{block->ID()};
103+
auto prepared = exe.Prepare(program, prefetch_block_ids);
102104
InitTensorsOnServer(&scope, &place, 10);
103105

106+
std::unordered_map<std::string,
107+
std::shared_ptr<framework::ExecutorPrepareContext>>
108+
prefetch_var_name_to_prepared;
109+
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
104110
g_req_handler->SetProgram(&program);
105-
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
111+
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
106112
g_req_handler->SetDevCtx(&ctx);
107113
g_req_handler->SetScope(&scope);
108114
g_req_handler->SetExecutor(&exe);

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,22 @@ static int64_t GetTimestamp() {
9696
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
9797
}
9898

99-
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
100-
framework::ProgramDesc *program,
101-
framework::Scope *recv_scope,
102-
framework::BlockDesc *prefetch_block) const {
99+
void ListenAndServOp::RunSyncLoop(
100+
framework::Executor *executor, framework::ProgramDesc *program,
101+
framework::Scope *recv_scope,
102+
const std::vector<int> &prefetch_block_id_list) const {
103103
size_t num_blocks = program->Size();
104104
PADDLE_ENFORCE_GE(num_blocks, 2,
105105
"server program should have at least 2 blocks");
106106

107-
std::vector<int> block_list;
108-
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
109-
block_list.push_back(blkid);
107+
std::vector<int> optimize_block_id_list;
108+
for (int blkid = 1; blkid < num_blocks; ++blkid) {
109+
if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
110+
blkid) == prefetch_block_id_list.end()) {
111+
optimize_block_id_list.push_back(blkid);
112+
}
110113
}
111-
auto optimize_prepared = executor->Prepare(*program, block_list);
114+
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
112115
// Insert placeholder for block0 which holds current op itself.
113116
optimize_prepared.insert(
114117
optimize_prepared.begin(),
@@ -135,16 +138,17 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
135138
std::vector<size_t> parallel_blkids;
136139
parallel_blkids.push_back(1);
137140
double ts = GetTimestamp();
138-
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
139-
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
140-
if (program->Block(blkid).Parent() != last_parent_blkid) {
141-
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
142-
program, recv_scope);
143-
parallel_blkids.clear();
144-
last_parent_blkid = program->Block(blkid).Parent();
145-
}
146-
parallel_blkids.push_back(blkid);
141+
for (size_t i = 1; i < optimize_block_id_list.size(); ++i) {
142+
// skip the first optimize block because it is already in the
143+
// parallel_blkids.
144+
int blkid = optimize_block_id_list[i];
145+
if (program->Block(blkid).Parent() != last_parent_blkid) {
146+
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
147+
program, recv_scope);
148+
parallel_blkids.clear();
149+
last_parent_blkid = program->Block(blkid).Parent();
147150
}
151+
parallel_blkids.push_back(blkid);
148152
}
149153
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
150154
recv_scope);
@@ -210,18 +214,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
210214
} // while(true)
211215
}
212216

213-
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
214-
platform::DeviceContext *dev_ctx,
215-
framework::Executor *executor,
216-
framework::ProgramDesc *program,
217-
framework::ExecutorPrepareContext *prefetch_ctx,
218-
detail::RPCServer *rpc_server) {
217+
static void FillRequestCtx(
218+
detail::RequestHandler *h, framework::Scope *scope,
219+
platform::DeviceContext *dev_ctx, framework::Executor *executor,
220+
framework::ProgramDesc *program,
221+
std::unordered_map<std::string,
222+
std::shared_ptr<framework::ExecutorPrepareContext>>
223+
*prefetch_ctx,
224+
detail::RPCServer *rpc_server) {
219225
h->SetScope(scope);
220226
h->SetDevCtx(dev_ctx);
221227
h->SetExecutor(executor);
222228
h->SetProgram(program);
223-
h->SetPrefetchPreparedCtx(
224-
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
229+
h->SetPrefetchPreparedCtx(prefetch_ctx);
225230
h->SetRPCServer(rpc_server);
226231
}
227232

@@ -255,17 +260,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
255260
request_prefetch_handler_.get());
256261

257262
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
258-
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
259263
auto *program = optimize_block->Program();
260264
framework::Executor executor(dev_place);
261265

262266
// prepare for prefetch
263-
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
264-
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
267+
std::vector<int> prefetch_block_id_list;
268+
std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
269+
270+
auto prefetch_var_name_to_block_id_str =
271+
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
272+
for (const auto &prefetch_var_name_and_id :
273+
prefetch_var_name_to_block_id_str) {
274+
std::vector<std::string> pieces;
275+
split(prefetch_var_name_and_id, ':', &pieces);
276+
VLOG(3) << "after split, prefetch_var = " << pieces[0]
277+
<< ", id=" << pieces[1];
278+
PADDLE_ENFORCE_EQ(pieces.size(), 2);
279+
280+
int block_id = std::stoi(pieces[1]);
281+
prefetch_block_id_list.push_back(block_id);
282+
block_id_to_prefetch_var_name[block_id] = pieces[0];
283+
}
284+
285+
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
286+
287+
std::unordered_map<std::string,
288+
std::shared_ptr<framework::ExecutorPrepareContext>>
289+
prefetch_var_name_to_prepared_ctx;
290+
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
291+
auto block_id = prefetch_block_id_list[i];
292+
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
293+
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
294+
}
265295

266296
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
267-
&dev_ctx, &executor, program, prefetch_prepared.release(),
268-
rpc_service_.get());
297+
&dev_ctx, &executor, program,
298+
&prefetch_var_name_to_prepared_ctx, rpc_service_.get());
269299

270300
f(request_send_handler_.get());
271301
f(request_get_handler_.get());
@@ -283,7 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
283313
// Write to a file of server selected port for python use.
284314
SavePort();
285315
if (sync_mode) {
286-
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
316+
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
287317
} else {
288318
RunAsyncLoop(&executor, program);
289319
}
@@ -309,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
309339
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
310340
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
311341
"BlockID to run on server side.");
312-
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
313-
"prefetch block to run on server side.");
342+
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
343+
"prefetch blocks to run on server side.")
344+
.SetDefault({});
314345
AddAttr<int>("Fanin", "How many clients send to this server.")
315346
.SetDefault(1);
316347
}

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include <atomic>
1919
#include <set>
2020
#include <string>
21+
#include <vector>
2122

2223
#include "paddle/fluid/framework/executor.h"
2324
#include "paddle/fluid/framework/lod_tensor.h"
@@ -30,7 +31,7 @@ namespace paddle {
3031
namespace operators {
3132

3233
constexpr char kOptimizeBlock[] = "OptimizeBlock";
33-
constexpr char kPrefetchBlock[] = "PrefetchBlock";
34+
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
3435

3536
void RunServer(std::shared_ptr<detail::RPCServer> service);
3637

@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
4647
void RunSyncLoop(framework::Executor* executor,
4748
framework::ProgramDesc* program,
4849
framework::Scope* recv_scope,
49-
framework::BlockDesc* prefetch_block) const;
50+
const std::vector<int>& prefetch_block_id_list) const;
5051

5152
void RunAsyncLoop(framework::Executor* executor,
5253
framework::ProgramDesc* program) const;

0 commit comments

Comments
 (0)