Skip to content

Commit 4e36c0e

Browse files
committed
update prefetch logic in grpc_server
1 parent 0d3d4ae commit 4e36c0e

File tree

6 files changed

+86
-46
lines changed

6 files changed

+86
-46
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,16 +155,18 @@ class RequestPrefetch final : public RequestBase {
155155

156156
void Process() override {
157157
// prefetch process...
158-
std::string varname = request_->OutVarname();
159-
VLOG(3) << "RequestPrefetch " << varname;
158+
std::string in_var_name = request_->Varname();
159+
std::string out_var_name = request_->OutVarname();
160+
VLOG(3) << "in_var_name: " << in_var_name
161+
<< " RequestPrefetch: " << out_var_name;
160162

161163
auto scope = request_->GetMutableLocalScope();
162-
auto invar = scope->FindVar(varname);
164+
auto invar = scope->FindVar(in_var_name);
163165
framework::Variable* outvar = nullptr;
164166

165-
request_handler_->Handle(varname, scope, invar, &outvar);
167+
request_handler_->Handle(in_var_name, scope, invar, &outvar);
166168

167-
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
169+
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
168170
&reply_);
169171
responder_.Finish(reply_, ::grpc::Status::OK,
170172
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,17 @@ void StartServer() {
9999
framework::Executor exe(place);
100100
platform::CPUDeviceContext ctx(place);
101101
auto* block = AppendPrefetchBlcok(&program);
102-
auto prepared = exe.Prepare(program, block->ID());
102+
std::string in_var_name("ids");
103+
std::vector<int> prefetch_block_ids{block->ID()};
104+
auto prepared = exe.Prepare(program, prefetch_block_ids);
103105
InitTensorsOnServer(&scope, &place, 10);
104106

107+
std::unordered_map<std::string,
108+
std::shared_ptr<framework::ExecutorPrepareContext>>
109+
prefetch_var_name_to_prepared;
110+
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
105111
g_req_handler->SetProgram(&program);
106-
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
112+
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
107113
g_req_handler->SetDevCtx(&ctx);
108114
g_req_handler->SetScope(&scope);
109115
g_req_handler->SetExecutor(&exe);

paddle/fluid/operators/detail/request_handler.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ class RequestHandler {
5757
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
5858
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
5959
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
60+
61+
// Used for distributed lookup table prefetch.
6062
void SetPrefetchPreparedCtx(
61-
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
62-
prefetch_ctx_.reset(prepared.release());
63+
std::unordered_map<
64+
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
65+
prefetch_var_name_to_prepared_ctx_ = g;
6366
}
6467

6568
// Used for async.
@@ -75,9 +78,6 @@ class RequestHandler {
7578
bool sync_mode() { return sync_mode_; }
7679
framework::Scope* scope() { return scope_; }
7780
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
78-
framework::ExecutorPrepareContext* prefetch_ctx() {
79-
return prefetch_ctx_.get();
80-
}
8181
framework::ProgramDesc* program() { return program_; }
8282
framework::Executor* executor() { return executor_; }
8383

@@ -106,12 +106,17 @@ class RequestHandler {
106106
framework::Executor* executor_;
107107
framework::Scope* scope_;
108108
framework::ProgramDesc* program_;
109-
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
109+
110+
// Used for distributed lookup table prefetch.
111+
std::unordered_map<std::string,
112+
std::shared_ptr<framework::ExecutorPrepareContext>>*
113+
prefetch_var_name_to_prepared_ctx_;
110114

111115
// Used for async.
112116
std::unordered_map<std::string,
113117
std::shared_ptr<framework::ExecutorPrepareContext>>*
114118
grad_to_prepared_ctx_;
119+
115120
RPCServer* rpc_server_;
116121
};
117122

paddle/fluid/operators/detail/request_handler_impl.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
111111
auto var_desc = program_->Block(0).FindVar(varname);
112112
*outvar = scope->FindVar(varname);
113113
InitializeVariable(*outvar, var_desc->GetType());
114-
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
114+
executor_->RunPreparedContext(
115+
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
115116

116117
return true;
117118
}

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,19 @@ void ListenAndServOp::SavePort() const {
8989
rpc_service_->SavePort();
9090
}
9191

92-
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
93-
framework::ProgramDesc *program,
94-
framework::Scope *recv_scope,
95-
framework::BlockDesc *prefetch_block) const {
92+
void ListenAndServOp::RunSyncLoop(
93+
framework::Executor *executor, framework::ProgramDesc *program,
94+
framework::Scope *recv_scope,
95+
const std::vector<int> &prefetch_block_id_list) const {
96+
// FIXME(qiao): the sync loop should not run the blocks used for prefetch; currently the prefetch
97+
// blocks
98+
// can only be the last blocks of the program.
9699
size_t num_blocks = program->Size();
97100
PADDLE_ENFORCE_GE(num_blocks, 2,
98101
"server program should have at least 2 blocks");
99102

100103
std::vector<int> block_list;
101-
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
104+
for (size_t blkid = 1; blkid < prefetch_block_id_list[0]; ++blkid) {
102105
block_list.push_back(blkid);
103106
}
104107
auto optimize_prepared = executor->Prepare(*program, block_list);
@@ -128,16 +131,14 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
128131
std::vector<size_t> parallel_blkids;
129132
parallel_blkids.push_back(1);
130133
double ts = detail::GetTimestamp();
131-
for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
132-
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
133-
if (program->Block(blkid).Parent() != last_parent_blkid) {
134-
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
135-
program, recv_scope);
136-
parallel_blkids.clear();
137-
last_parent_blkid = program->Block(blkid).Parent();
138-
}
139-
parallel_blkids.push_back(blkid);
134+
for (size_t blkid = 2; blkid < prefetch_block_id_list[0]; ++blkid) {
135+
if (program->Block(blkid).Parent() != last_parent_blkid) {
136+
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
137+
program, recv_scope);
138+
parallel_blkids.clear();
139+
last_parent_blkid = program->Block(blkid).Parent();
140140
}
141+
parallel_blkids.push_back(blkid);
141142
}
142143
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
143144
recv_scope);
@@ -203,18 +204,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
203204
} // while(true)
204205
}
205206

206-
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
207-
platform::DeviceContext *dev_ctx,
208-
framework::Executor *executor,
209-
framework::ProgramDesc *program,
210-
framework::ExecutorPrepareContext *prefetch_ctx,
211-
detail::RPCServer *rpc_server) {
207+
static void FillRequestCtx(
208+
detail::RequestHandler *h, framework::Scope *scope,
209+
platform::DeviceContext *dev_ctx, framework::Executor *executor,
210+
framework::ProgramDesc *program,
211+
std::unordered_map<std::string,
212+
std::shared_ptr<framework::ExecutorPrepareContext>>
213+
*prefetch_ctx,
214+
detail::RPCServer *rpc_server) {
212215
h->SetScope(scope);
213216
h->SetDevCtx(dev_ctx);
214217
h->SetExecutor(executor);
215218
h->SetProgram(program);
216-
h->SetPrefetchPreparedCtx(
217-
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
219+
h->SetPrefetchPreparedCtx(prefetch_ctx);
218220
h->SetRPCServer(rpc_server);
219221
}
220222

@@ -248,18 +250,41 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
248250
request_prefetch_handler_.get());
249251

250252
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
251-
auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock);
252-
framework::BlockDesc *prefetch_block = nullptr;
253253
auto *program = optimize_block->Program();
254254
framework::Executor executor(dev_place);
255255

256256
// prepare for prefetch
257-
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
258-
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
257+
std::vector<int> prefetch_block_id_list;
258+
std::unordered_map<int32_t, std::string> block_id_to_prefetch_var_name;
259+
260+
auto prefetch_var_name_to_block_id_str =
261+
Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
262+
for (const auto &prefetch_var_name_and_id :
263+
prefetch_var_name_to_block_id_str) {
264+
std::vector<std::string> pieces;
265+
split(prefetch_var_name_and_id, ':', &pieces);
266+
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
267+
PADDLE_ENFORCE_EQ(pieces.size(), 2);
268+
269+
int block_id = std::stoi(pieces[1]);
270+
prefetch_block_id_list.push_back(block_id);
271+
block_id_to_prefetch_var_name[block_id] = pieces[0];
272+
}
273+
274+
auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);
275+
276+
std::unordered_map<std::string,
277+
std::shared_ptr<framework::ExecutorPrepareContext>>
278+
prefetch_var_name_to_prepared_ctx;
279+
for (int i = 0; i < prefetch_block_id_list.size(); ++i) {
280+
auto block_id = prefetch_block_id_list[i];
281+
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
282+
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
283+
}
259284

260285
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
261-
&dev_ctx, &executor, program, prefetch_prepared.release(),
262-
rpc_service_.get());
286+
&dev_ctx, &executor, program,
287+
&prefetch_var_name_to_prepared_ctx, rpc_service_.get());
263288

264289
f(request_send_handler_.get());
265290
f(request_get_handler_.get());
@@ -277,7 +302,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
277302
// Write to a file of server selected port for python use.
278303
SavePort();
279304
if (sync_mode) {
280-
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
305+
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list);
281306
} else {
282307
RunAsyncLoop(&executor, program);
283308
}
@@ -303,7 +328,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
303328
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
304329
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
305330
"BlockID to run on server side.");
306-
AddAttr<std::vector<std::string>>(kPrefetchBlock,
331+
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
307332
"prefetch block to run on server side.");
308333
AddAttr<int>("Fanin", "How many clients send to this server.")
309334
.SetDefault(1);

paddle/fluid/operators/listen_and_serv_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include <atomic>
1919
#include <set>
2020
#include <string>
21+
#include <vector>
2122

2223
#include "paddle/fluid/framework/executor.h"
2324
#include "paddle/fluid/framework/lod_tensor.h"
@@ -30,7 +31,7 @@ namespace paddle {
3031
namespace operators {
3132

3233
constexpr char kOptimizeBlock[] = "OptimizeBlock";
33-
constexpr char kPrefetchBlock[] = "prefetch_var_name_to_block_id";
34+
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
3435

3536
void RunServer(std::shared_ptr<detail::RPCServer> service);
3637

@@ -46,7 +47,7 @@ class ListenAndServOp : public framework::OperatorBase {
4647
void RunSyncLoop(framework::Executor* executor,
4748
framework::ProgramDesc* program,
4849
framework::Scope* recv_scope,
49-
framework::BlockDesc* prefetch_block) const;
50+
const std::vector<int>& prefetch_block_id_list) const;
5051

5152
void RunAsyncLoop(framework::Executor* executor,
5253
framework::ProgramDesc* program) const;

0 commit comments

Comments
 (0)