Skip to content

Commit b55dc9a

Browse files
author
Yancey
authored
Merge pull request #9536 from Yancey1989/prefetch_on_server
Add prefetch interface on server side
2 parents ef802ce + eb04ccb commit b55dc9a

File tree

7 files changed

+135
-3
lines changed

7 files changed

+135
-3
lines changed

paddle/fluid/operators/detail/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ if(WITH_DISTRIBUTE)
22
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
33
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
44
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
5-
set_source_files_properties(serde_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
5+
set_source_files_properties(serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
66
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
77
cares zlib protobuf sendrecvop_grpc)
8+
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
89
endif()

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
150150
s->response_call_back_ = ProcGetResponse;
151151

152152
auto call = s->stub_g_.PrepareUnaryCall(
153-
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", req, &cq_);
153+
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
154+
&cq_);
154155
call->StartCall();
155156
call->Finish(&s->reply_, &s->status_, (void*)s);
156157
});

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,47 @@ class RequestGet final : public RequestBase {
128128
SimpleBlockQueue<MessageWithName>* queue_;
129129
};
130130

131+
class RequestPrefetch final : public RequestBase {
132+
public:
133+
explicit RequestPrefetch(GrpcService::AsyncService* service,
134+
::grpc::ServerCompletionQueue* cq,
135+
framework::Scope* scope,
136+
const platform::DeviceContext* dev_ctx,
137+
framework::Executor* executor,
138+
framework::ProgramDesc* program, int blkid)
139+
: RequestBase(service, cq, dev_ctx),
140+
responder_(&ctx_),
141+
scope_(scope),
142+
executor_(executor),
143+
program_(program),
144+
blkid_(blkid) {
145+
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
146+
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
147+
cq_, this);
148+
}
149+
150+
virtual ~RequestPrefetch() {}
151+
152+
virtual std::string GetReqName() { return request_.varname(); }
153+
154+
virtual void Process() {
155+
// prefetch process...
156+
::grpc::ByteBuffer reply;
157+
// TODO(Yancey1989): execute the Block which containers prefetch ops
158+
159+
responder_.Finish(reply, ::grpc::Status::OK, this);
160+
status_ = FINISH;
161+
}
162+
163+
protected:
164+
sendrecv::VariableMessage request_;
165+
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
166+
framework::Scope* scope_;
167+
framework::Executor* executor_;
168+
framework::ProgramDesc* program_;
169+
int blkid_;
170+
};
171+
131172
void AsyncGRPCServer::WaitClientGet(int count) {
132173
int fetch_barriers = 0;
133174
while (fetch_barriers < count) {
@@ -147,6 +188,7 @@ void AsyncGRPCServer::RunSyncUpdate() {
147188

148189
cq_send_ = builder.AddCompletionQueue();
149190
cq_get_ = builder.AddCompletionQueue();
191+
cq_prefetch_ = builder.AddCompletionQueue();
150192

151193
server_ = builder.BuildAndStart();
152194
LOG(INFO) << "Server listening on " << address_ << std::endl;
@@ -155,6 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() {
155197
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
156198
std::function<void()> get_register =
157199
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
200+
std::function<void()> prefetch_register =
201+
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this);
158202

159203
t_send_.reset(
160204
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
@@ -163,11 +207,14 @@ void AsyncGRPCServer::RunSyncUpdate() {
163207
t_get_.reset(
164208
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
165209
cq_get_.get(), "cq_get", get_register)));
166-
210+
t_prefetch_.reset(new std::thread(
211+
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
212+
"cq_prefetch", prefetch_register)));
167213
// wait server
168214
server_->Wait();
169215
t_send_->join();
170216
t_get_->join();
217+
t_prefetch_->join();
171218
}
172219

173220
void AsyncGRPCServer::ShutdownQueue() {
@@ -203,6 +250,18 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
203250
VLOG(4) << "Create RequestGet status:" << get->Status();
204251
}
205252

253+
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
254+
std::unique_lock<std::mutex> lock(cq_mutex_);
255+
if (is_shut_down_) {
256+
return;
257+
}
258+
RequestPrefetch* prefetch =
259+
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
260+
executor_, program_, prefetch_blk_id_);
261+
262+
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
263+
}
264+
206265
// FIXME(typhoonzero): change cq_name to enum.
207266
void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
208267
std::string cq_name,

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ limitations under the License. */
1717
#include <grpc++/grpc++.h>
1818
#include <thread>
1919

20+
#include "paddle/fluid/framework/executor.h"
2021
#include "paddle/fluid/framework/lod_tensor.h"
22+
#include "paddle/fluid/framework/program_desc.h"
2123
#include "paddle/fluid/framework/scope.h"
2224
#include "paddle/fluid/framework/selected_rows.h"
2325
#include "paddle/fluid/framework/var_type.h"
@@ -53,6 +55,12 @@ class AsyncGRPCServer final {
5355

5456
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
5557

58+
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
59+
60+
void SetPrefetchBlkdId(int blkid) { prefetch_blk_id_ = blkid; }
61+
62+
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
63+
5664
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
5765

5866
void Push(const std::string &msg_name) {
@@ -66,13 +74,15 @@ class AsyncGRPCServer final {
6674
std::function<void()> TryToRegisterNewOne);
6775
void TryToRegisterNewSendOne();
6876
void TryToRegisterNewGetOne();
77+
void TryToRegisterNewPrefetchOne();
6978
void ShutdownQueue();
7079

7180
private:
7281
std::mutex cq_mutex_;
7382
volatile bool is_shut_down_ = false;
7483
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
7584
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
85+
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
7686

7787
GrpcService::AsyncService service_;
7888
std::unique_ptr<::grpc::Server> server_;
@@ -92,6 +102,11 @@ class AsyncGRPCServer final {
92102

93103
std::unique_ptr<std::thread> t_send_;
94104
std::unique_ptr<std::thread> t_get_;
105+
std::unique_ptr<std::thread> t_prefetch_;
106+
107+
int prefetch_blk_id_;
108+
framework::ProgramDesc *program_;
109+
framework::Executor *executor_;
95110
};
96111

97112
}; // namespace detail
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <unistd.h>
16+
#include <string>
17+
#include <thread>
18+
19+
#include "gtest/gtest.h"
20+
#include "paddle/fluid/operators/detail/grpc_client.h"
21+
#include "paddle/fluid/operators/detail/grpc_server.h"
22+
23+
namespace framework = paddle::framework;
24+
namespace platform = paddle::platform;
25+
namespace detail = paddle::operators::detail;
26+
27+
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
28+
29+
void StartServer(const std::string& endpoint) {
30+
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
31+
}
32+
33+
TEST(PREFETCH, CPU) {
34+
// start up a server instance backend
35+
// TODO(Yancey1989): Need to start a server with optimize blocks and
36+
// prefetch blocks.
37+
std::thread server_thread(StartServer, "127.0.0.1:8889");
38+
framework::Scope scope;
39+
platform::CPUPlace place;
40+
platform::CPUDeviceContext ctx(place);
41+
// create var on local scope
42+
std::string var_name("tmp_0");
43+
auto var = scope.Var(var_name);
44+
auto tensor = var->GetMutable<framework::LoDTensor>();
45+
tensor->Resize({10, 10});
46+
47+
detail::RPCClient client;
48+
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
49+
server_thread.join();
50+
rpc_service_.reset(nullptr);
51+
}

paddle/fluid/operators/detail/grpc_service.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ namespace detail {
7676
enum class GrpcMethod {
7777
kSendVariable,
7878
kGetVariable,
79+
kPrefetchVariable,
7980
};
8081

8182
static const int kGrpcNumMethods =
@@ -87,6 +88,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
8788
return "/sendrecv.SendRecvService/SendVariable";
8889
case GrpcMethod::kGetVariable:
8990
return "/sendrecv.SendRecvService/GetVariable";
91+
case GrpcMethod::kPrefetchVariable:
92+
return "/sendrecv.SendREcvService/PrefetchVariable";
9093
}
9194

9295
// Shouldn't be reached.

paddle/fluid/operators/detail/send_recv.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ service SendRecvService {
2121
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
2222
// Argument VariableMessage for GetVariable should only contain varname.
2323
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
24+
// Prefetch variable by Ids
25+
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
2426
}
2527

2628
// VariableMessage is serialized paddle variable message.

0 commit comments

Comments
 (0)