Skip to content

Commit be85385

Browse files
author
Yancey
authored
Merge pull request #9593 from Yancey1989/prefech_prog_on_server
run prefetch prog on server
2 parents 7d39725 + 974b253 commit be85385

File tree

12 files changed

+147
-28
lines changed

12 files changed

+147
-28
lines changed

paddle/fluid/framework/scope.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/scope.h"
1616

1717
#include <memory> // for unique_ptr
18-
#include <mutex> // for call_once
1918
#include <set>
2019
#include "glog/logging.h"
2120
#include "paddle/fluid/framework/threadpool.h"
@@ -39,6 +38,7 @@ Scope::~Scope() {
3938
}
4039

4140
Scope& Scope::NewScope() const {
41+
std::unique_lock<std::mutex> lock(mutex_);
4242
kids_.push_back(new Scope(this));
4343
return *kids_.back();
4444
}
@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
9292
}
9393

9494
void Scope::DeleteScope(Scope* scope) {
95+
std::unique_lock<std::mutex> lock(mutex_);
9596
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
9697
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
9798
this->kids_.erase(it);
@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
103104
}
104105
}
105106

106-
void Scope::EraseVars(std::vector<std::string>& var_names) {
107+
void Scope::EraseVars(const std::vector<std::string>& var_names) {
107108
std::set<std::string> var_set(var_names.begin(), var_names.end());
108109
for (auto it = vars_.begin(); it != vars_.end();) {
109110
if (var_set.find(it->first) != var_set.end()) {

paddle/fluid/framework/scope.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <list>
18+
#include <mutex> // NOLINT
1819
#include <string>
1920
#include <unordered_map>
2021
#include <vector>
@@ -51,7 +52,7 @@ class Scope {
5152
/// Create a variable with a scope-unique name.
5253
Variable* Var(std::string* name = nullptr);
5354

54-
void EraseVars(std::vector<std::string>& var_names);
55+
void EraseVars(const std::vector<std::string>& var_names);
5556

5657
/// Find a variable in the scope or any of its ancestors. Returns
5758
/// nullptr if cannot find.
@@ -88,6 +89,9 @@ class Scope {
8889
Scope const* parent_{nullptr};
8990

9091
DISABLE_COPY_AND_ASSIGN(Scope);
92+
93+
private:
94+
mutable std::mutex mutex_;
9195
};
9296
} // namespace framework
9397
} // namespace paddle

paddle/fluid/operators/detail/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
55
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
66
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
77
cares zlib protobuf sendrecvop_grpc)
8-
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
8+
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op)
99
endif()

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
138138
auto* var = p_scope->FindVar(in_var_name_val);
139139

140140
::grpc::ByteBuffer req;
141-
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req);
141+
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
142142

143143
// var handle
144144
VarHandle var_h;

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
138138
framework::Scope* scope,
139139
const platform::DeviceContext* dev_ctx,
140140
framework::Executor* executor,
141-
framework::ProgramDesc* program, int blkid)
141+
framework::ProgramDesc* program,
142+
framework::ExecutorPrepareContext* prefetch_ctx)
142143
: RequestBase(service, cq, dev_ctx),
143144
responder_(&ctx_),
144145
scope_(scope),
145146
executor_(executor),
146147
program_(program),
147-
blkid_(blkid) {
148+
prefetch_ctx_(prefetch_ctx) {
149+
request_.reset(new VariableResponse(scope, dev_ctx_));
148150
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
149-
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
150-
cq_, this);
151+
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
152+
cq_, cq_, this);
151153
}
152154

153155
virtual ~RequestPrefetch() {}
154156

155-
virtual std::string GetReqName() { return request_.varname(); }
157+
virtual std::string GetReqName() { return request_->Varname(); }
156158

157159
virtual void Process() {
158160
// prefetch process...
159161
::grpc::ByteBuffer reply;
160-
// TODO(Yancey1989): execute the Block which containers prefetch ops
161162

162-
VLOG(3) << "RequestPrefetch Process in";
163+
std::string var_name = request_->OutVarname();
164+
auto var_desc = program_->Block(0).FindVar(var_name);
165+
framework::Scope* local_scope = &scope_->NewScope();
166+
auto* var = local_scope->FindVar(var_name);
167+
InitializeVariable(var, var_desc->GetType());
168+
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
169+
170+
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
163171

164172
responder_.Finish(reply, ::grpc::Status::OK, this);
165173
status_ = FINISH;
166174
}
167175

168176
protected:
169-
sendrecv::VariableMessage request_;
177+
std::shared_ptr<VariableResponse> request_;
170178
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
171179
framework::Scope* scope_;
172180
framework::Executor* executor_;
173181
framework::ProgramDesc* program_;
182+
framework::ExecutorPrepareContext* prefetch_ctx_;
174183
int blkid_;
175184
};
176185

@@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
268277
}
269278
RequestPrefetch* prefetch =
270279
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
271-
executor_, program_, prefetch_blk_id_);
280+
executor_, program_, prefetch_ctx_);
272281

273282
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
274283
}

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
6363

6464
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
6565

66+
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
67+
prefetch_ctx_ = prepared;
68+
}
69+
6670
int GetSelectedPort() { return selected_port_; }
6771

6872
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
111115
std::unique_ptr<std::thread> t_prefetch_;
112116

113117
int prefetch_blk_id_;
118+
framework::ExecutorPrepareContext *prefetch_ctx_;
114119
framework::ProgramDesc *program_;
115120
framework::Executor *executor_;
116121
int selected_port_;

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,43 +20,121 @@ limitations under the License. */
2020
#include "paddle/fluid/operators/detail/grpc_client.h"
2121
#include "paddle/fluid/operators/detail/grpc_server.h"
2222

23+
#include "paddle/fluid/framework/block_desc.h"
24+
#include "paddle/fluid/framework/op_registry.h"
25+
#include "paddle/fluid/framework/operator.h"
26+
2327
namespace framework = paddle::framework;
2428
namespace platform = paddle::platform;
2529
namespace detail = paddle::operators::detail;
2630

31+
USE_OP(lookup_table);
32+
2733
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
2834

35+
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
36+
auto root_block = program->MutableBlock(0);
37+
auto* block = program->AppendBlock(*root_block);
38+
39+
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
40+
framework::VariableNameMap output({{"Output", {"out"}}});
41+
auto op = block->AppendOp();
42+
op->SetType("lookup_table");
43+
op->SetInput("W", {"w"});
44+
op->SetInput("Ids", {"ids"});
45+
op->SetOutput("Out", {"out"});
46+
47+
auto& out = *root_block->Var("out");
48+
out.SetType(framework::proto::VarType::SELECTED_ROWS);
49+
out.SetShape({10, 10});
50+
51+
return block;
52+
}
53+
54+
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
55+
auto w_var = scope->Var("w");
56+
w_var->GetMutable<framework::SelectedRows>();
57+
58+
auto out_var = scope->Var("out");
59+
out_var->GetMutable<framework::SelectedRows>();
60+
61+
auto ids_var = scope->Var("ids");
62+
ids_var->GetMutable<framework::SelectedRows>();
63+
}
64+
65+
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
66+
int64_t rows_numel) {
67+
CreateVarsOnScope(scope, place);
68+
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
69+
auto rows = ids_var->mutable_rows();
70+
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
71+
ids_var->mutable_value()->Resize({rows_numel, 1});
72+
ids_var->mutable_value()->mutable_data<float>(*place);
73+
}
74+
75+
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
76+
int64_t rows_numel) {
77+
CreateVarsOnScope(scope, place);
78+
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
79+
auto rows = w->mutable_rows();
80+
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i);
81+
auto w_value = w->mutable_value();
82+
w_value->Resize({rows_numel, 10});
83+
84+
auto ptr = w_value->mutable_data<float>(*place);
85+
86+
for (int64_t i = 0; i < w_value->numel(); ++i) {
87+
ptr[i] = static_cast<float>(i / 10);
88+
}
89+
}
90+
2991
void StartServer(const std::string& endpoint) {
3092
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
93+
framework::ProgramDesc program;
94+
framework::Scope scope;
95+
platform::CPUPlace place;
96+
framework::Executor exe(place);
97+
platform::CPUDeviceContext ctx(place);
98+
auto* block = AppendPrefetchBlcok(&program);
99+
auto prepared = exe.Prepare(program, block->ID());
100+
InitTensorsOnServer(&scope, &place, 10);
101+
102+
rpc_service_->SetProgram(&program);
103+
rpc_service_->SetPrefetchPreparedCtx(prepared.get());
104+
rpc_service_->SetDevCtx(&ctx);
105+
rpc_service_->SetScope(&scope);
106+
rpc_service_->SetExecutor(&exe);
107+
31108
rpc_service_->RunSyncUpdate();
32109
}
33110

34111
TEST(PREFETCH, CPU) {
35112
// start up a server instance backend
36-
// TODO(Yancey1989): Need to start a server with optimize blocks and
37-
// prefetch blocks.
38113
std::thread server_thread(StartServer, "127.0.0.1:8889");
114+
sleep(2);
39115
framework::Scope scope;
40116
platform::CPUPlace place;
41117
platform::CPUDeviceContext ctx(place);
42118
// create var on local scope
43-
std::string in_var_name("in");
119+
int64_t rows_numel = 5;
120+
InitTensorsOnClient(&scope, &place, rows_numel);
121+
std::string in_var_name("ids");
44122
std::string out_var_name("out");
45-
auto* in_var = scope.Var(in_var_name);
46-
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
47-
in_tensor->Resize({10, 10});
48-
VLOG(3) << "before mutable_data";
49-
in_tensor->mutable_data<int>(place);
50123

51-
scope.Var(out_var_name);
52-
53-
VLOG(3) << "before fetch";
54124
detail::RPCClient client;
55125
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
56126
out_var_name);
57127
client.Wait();
58128

129+
auto var = scope.Var(out_var_name);
130+
auto value = var->GetMutable<framework::SelectedRows>()->value();
131+
auto ptr = value.mutable_data<float>(place);
132+
59133
rpc_service_->ShutDown();
60134
server_thread.join();
61135
rpc_service_.reset(nullptr);
136+
137+
for (int64_t i = 0; i < rows_numel; ++i) {
138+
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
139+
}
62140
}

paddle/fluid/operators/detail/send_recv.proto

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ service SendRecvService {
2121
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
2222
// Argument VariableMessage for GetVariable should only contain varname.
2323
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
24-
// Prefetch variable by Ids
24+
// pre-fetch variable by given variable name and Ids
2525
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
2626
}
2727

@@ -67,6 +67,8 @@ message VariableMessage {
6767
bytes serialized = 8;
6868
// selected_rows data
6969
bytes rows = 9;
70+
// Look up table block execution output variable name.
71+
string out_varname = 10;
7072
}
7173

7274
message VoidMessage {}

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ namespace detail {
3030

3131
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
3232
const platform::DeviceContext& ctx,
33-
::grpc::ByteBuffer* msg) {
33+
::grpc::ByteBuffer* msg,
34+
const std::string& out_name) {
3435
using VarMsg = sendrecv::VariableMessage;
3536
sendrecv::VariableMessage request;
3637
std::string header;
@@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
5253
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
5354
}
5455

56+
if (!out_name.empty()) {
57+
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
58+
}
5559
switch (framework::ToVarType(var->Type())) {
5660
case framework::proto::VarType_Type_LOD_TENSOR: {
5761
auto tensor = var->Get<framework::LoDTensor>();

paddle/fluid/operators/detail/sendrecvop_utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);
4646

4747
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
4848
const platform::DeviceContext& ctx,
49-
::grpc::ByteBuffer* msg);
49+
::grpc::ByteBuffer* msg,
50+
const std::string& out_varname = std::string());
5051

5152
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
5253
const platform::DeviceContext& ctx,

0 commit comments

Comments
 (0)