Skip to content

Commit f132f51

Browse files
committed
prepare prefetch context
1 parent 4698966 commit f132f51

File tree

4 files changed

+14
-8
lines changed

4 files changed

+14
-8
lines changed

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,14 @@ class RequestPrefetch final : public RequestBase {
138138
framework::Scope* scope,
139139
const platform::DeviceContext* dev_ctx,
140140
framework::Executor* executor,
141-
framework::ProgramDesc* program, int blkid)
141+
framework::ProgramDesc* program,
142+
framework::ExecutorPrepareContext* prefetch_ctx)
142143
: RequestBase(service, cq, dev_ctx),
143144
responder_(&ctx_),
144145
scope_(scope),
145146
executor_(executor),
146147
program_(program),
147-
blkid_(blkid) {
148+
prefetch_ctx_(prefetch_ctx) {
148149
request_.reset(new VariableResponse(scope, dev_ctx_));
149150
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
150151
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
@@ -164,8 +165,7 @@ class RequestPrefetch final : public RequestBase {
164165
framework::Scope* local_scope = &scope_->NewScope();
165166
auto* var = local_scope->FindVar(var_name);
166167
InitializeVariable(var, var_desc->GetType());
167-
168-
executor_->Run(*program_, local_scope, blkid_, false, false);
168+
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
169169

170170
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
171171

@@ -179,6 +179,7 @@ class RequestPrefetch final : public RequestBase {
179179
framework::Scope* scope_;
180180
framework::Executor* executor_;
181181
framework::ProgramDesc* program_;
182+
framework::ExecutorPrepareContext* prefetch_ctx_;
182183
int blkid_;
183184
};
184185

@@ -276,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
276277
}
277278
RequestPrefetch* prefetch =
278279
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
279-
executor_, program_, prefetch_blk_id_);
280+
executor_, program_, prefetch_ctx_);
280281

281282
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
282283
}

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
6363

6464
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
6565

66+
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
67+
prefetch_ctx_ = prepared;
68+
}
69+
6670
int GetSelectedPort() { return selected_port_; }
6771

6872
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
111115
std::unique_ptr<std::thread> t_prefetch_;
112116

113117
int prefetch_blk_id_;
118+
framework::ExecutorPrepareContext *prefetch_ctx_;
114119
framework::ProgramDesc *program_;
115120
framework::Executor *executor_;
116121
int selected_port_;

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,11 @@ void StartServer(const std::string& endpoint) {
9696
framework::Executor exe(place);
9797
platform::CPUDeviceContext ctx(place);
9898
auto* block = AppendPrefetchBlcok(&program);
99+
auto prepared = exe.Prepare(program, block->ID());
99100
InitTensorsOnServer(&scope, &place, 10);
100101

101102
rpc_service_->SetProgram(&program);
102-
rpc_service_->SetPrefetchBlkdId(block->ID());
103+
rpc_service_->SetPrefetchPreparedCtx(prepared.get());
103104
rpc_service_->SetDevCtx(&ctx);
104105
rpc_service_->SetScope(&scope);
105106
rpc_service_->SetExecutor(&exe);
@@ -125,7 +126,6 @@ TEST(PREFETCH, CPU) {
125126
out_var_name);
126127
client.Wait();
127128

128-
// auto out_var = scope.Var(out_var_name);
129129
auto var = scope.Var(out_var_name);
130130
auto value = var->GetMutable<framework::SelectedRows>()->value();
131131
auto ptr = value.mutable_data<float>(place);

paddle/fluid/operators/detail/send_recv.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ service SendRecvService {
2121
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
2222
// Argument VariableMessage for GetVariable should only contain varname.
2323
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
24-
// Prefetch variable by Ids
24+
// Look up table block execution output variable name.
2525
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
2626
}
2727

0 commit comments

Comments
 (0)