Skip to content

Commit 09fcf5f

Browse files
author
Yancey
authored
Merge pull request #9555 from jacquesqiao/improve-prefetch-on-server
Improve prefetch on server
2 parents b9d8bbe + 04a5c03 commit 09fcf5f

File tree

7 files changed

+55
-32
lines changed

7 files changed

+55
-32
lines changed

paddle/fluid/operators/detail/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE)
22
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
33
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
44
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
5-
set_source_files_properties(serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
5+
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
66
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
77
cares zlib protobuf sendrecvop_grpc)
88
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "grpc_client.h"
16-
#include <sys/time.h>
15+
#include "paddle/fluid/operators/detail/grpc_client.h"
16+
17+
#include <limits>
18+
1719
#include "paddle/fluid/framework/threadpool.h"
1820

1921
namespace paddle {
@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
5254
auto call = s->stub_g_.PrepareUnaryCall(
5355
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
5456
call->StartCall();
55-
call->Finish(&s->reply_, &s->status_, (void*)s);
57+
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
5658
});
5759

5860
req_count_++;
@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h,
7072
template <typename T>
7173
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
7274
::grpc::Slice slice(proto.ByteSizeLong());
73-
proto.SerializeWithCachedSizesToArray(
74-
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(slice.begin())));
75+
proto.SerializeWithCachedSizesToArray(const_cast<uint8_t*>(slice.begin()));
7576
::grpc::ByteBuffer tmp(&slice, 1);
7677
result->Swap(&tmp);
7778
}
@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
109110
auto call = s->stub_g_.PrepareUnaryCall(
110111
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
111112
call->StartCall();
112-
call->Finish(&s->reply_, &s->status_, (void*)s);
113+
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
113114
});
114115

115116
req_count_++;
@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
153154
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
154155
&cq_);
155156
call->StartCall();
156-
call->Finish(&s->reply_, &s->status_, (void*)s);
157+
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
157158
});
158159

159160
req_count_++;
@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
169170
sendrecv::VariableMessage req;
170171
req.set_varname(BATCH_BARRIER_MESSAGE);
171172
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
172-
rpc->Finish(&s->reply_, &s->status_, (void*)s);
173+
rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
173174
req_count_++;
174175
}
175176

@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
181182
sendrecv::VariableMessage req;
182183
req.set_varname(FETCH_BARRIER_MESSAGE);
183184
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
184-
rpc->Finish(&s->reply_, &s->status_, (void*)s);
185+
rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
185186
req_count_++;
186187
}
187188

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/detail/grpc_server.h"
1616

17+
#include <limits>
18+
#include <string>
19+
1720
using ::grpc::ServerAsyncResponseWriter;
1821

1922
namespace paddle {
@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase {
156159
::grpc::ByteBuffer reply;
157160
// TODO(Yancey1989): execute the Block which containers prefetch ops
158161

162+
VLOG(3) << "RequestPrefetch Process in";
163+
159164
responder_.Finish(reply, ::grpc::Status::OK, this);
160165
status_ = FINISH;
161166
}
@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
221226
std::unique_lock<std::mutex> lock(cq_mutex_);
222227
cq_send_->Shutdown();
223228
cq_get_->Shutdown();
229+
cq_prefetch_->Shutdown();
224230
}
225231

226232
// This URL explains why shutdown is complicate:
@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
233239
void AsyncGRPCServer::TryToRegisterNewSendOne() {
234240
std::unique_lock<std::mutex> lock(cq_mutex_);
235241
if (is_shut_down_) {
242+
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
236243
return;
237244
}
238245
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
243250
void AsyncGRPCServer::TryToRegisterNewGetOne() {
244251
std::unique_lock<std::mutex> lock(cq_mutex_);
245252
if (is_shut_down_) {
253+
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
246254
return;
247255
}
248256
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
253261
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
254262
std::unique_lock<std::mutex> lock(cq_mutex_);
255263
if (is_shut_down_) {
264+
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
256265
return;
257266
}
258267
RequestPrefetch* prefetch =
@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
270279

271280
void* tag = NULL;
272281
bool ok = false;
282+
273283
while (true) {
284+
VLOG(3) << "HandleRequest for " << cq_name << " while in";
274285
if (!cq->Next(&tag, &ok)) {
275286
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
276287
break;
277288
}
289+
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
278290

279291
PADDLE_ENFORCE(tag);
280292
// FIXME(typhoonzero): de-couple the barriers with recv_op
281293
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
282294
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
283295

284-
RequestBase* base = (RequestBase*)tag;
296+
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
285297
// reference:
286298
// https://github.com/tensorflow/tensorflow/issues/5596
287299
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
288300
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
289301
if (!ok) {
290-
LOG(WARNING) << cq_name << " recv no regular event:argument name"
291-
<< base->GetReqName();
302+
LOG(WARNING) << cq_name << " recv no regular event:argument name["
303+
<< base->GetReqName() << "]";
292304
TryToRegisterNewOne();
293305
delete base;
294306
continue;

paddle/fluid/operators/detail/grpc_server.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <grpc++/grpc++.h>
18-
#include <thread>
18+
#include <string>
19+
#include <utility>
1920

2021
#include "paddle/fluid/framework/executor.h"
2122
#include "paddle/fluid/framework/lod_tensor.h"
@@ -93,6 +94,7 @@ class AsyncGRPCServer final {
9394

9495
// received variable from RPC, operators fetch variable from this queue.
9596
SimpleBlockQueue<MessageWithName> var_get_queue_;
97+
// client send variable to this queue.
9698
ReceivedQueue var_recv_queue_;
9799

98100
// condition of the sub program

paddle/fluid/operators/detail/grpc_server_test.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
2828

2929
void StartServer(const std::string& endpoint) {
3030
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
31+
rpc_service_->RunSyncUpdate();
3132
}
3233

3334
TEST(PREFETCH, CPU) {
@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
3940
platform::CPUPlace place;
4041
platform::CPUDeviceContext ctx(place);
4142
// create var on local scope
42-
std::string var_name("tmp_0");
43-
auto var = scope.Var(var_name);
44-
auto tensor = var->GetMutable<framework::LoDTensor>();
45-
tensor->Resize({10, 10});
43+
std::string in_var_name("in");
44+
std::string out_var_name("out");
45+
auto* in_var = scope.Var(in_var_name);
46+
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
47+
in_tensor->Resize({10, 10});
48+
VLOG(3) << "before mutable_data";
49+
in_tensor->mutable_data<int>(place);
4650

51+
scope.Var(out_var_name);
52+
53+
VLOG(3) << "before fetch";
4754
detail::RPCClient client;
48-
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
55+
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
56+
out_var_name);
57+
client.Wait();
58+
59+
rpc_service_->ShutDown();
4960
server_thread.join();
5061
rpc_service_.reset(nullptr);
5162
}

paddle/fluid/operators/detail/grpc_service.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ enum class GrpcMethod {
8080
};
8181

8282
static const int kGrpcNumMethods =
83-
static_cast<int>(GrpcMethod::kGetVariable) + 1;
83+
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
8484

8585
inline const char* GrpcMethodName(GrpcMethod id) {
8686
switch (id) {
@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
8989
case GrpcMethod::kGetVariable:
9090
return "/sendrecv.SendRecvService/GetVariable";
9191
case GrpcMethod::kPrefetchVariable:
92-
return "/sendrecv.SendREcvService/PrefetchVariable";
92+
return "/sendrecv.SendRecvService/PrefetchVariable";
9393
}
9494

9595
// Shouldn't be reached.
@@ -117,5 +117,5 @@ class GrpcService final {
117117
};
118118

119119
} // namespace detail
120-
} // namespace operator
120+
} // namespace operators
121121
} // namespace paddle

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <stdint.h>
16-
#include <sys/stat.h>
1716
#include <ostream>
18-
#include <thread>
19-
20-
#include <unistd.h>
2117

2218
#include "paddle/fluid/framework/executor.h"
23-
#include "paddle/fluid/framework/framework.pb.h"
2419
#include "paddle/fluid/framework/lod_tensor.h"
2520
#include "paddle/fluid/framework/op_registry.h"
26-
#include "paddle/fluid/framework/proto_desc.h"
2721
#include "paddle/fluid/framework/threadpool.h"
2822
#include "paddle/fluid/operators/detail/grpc_server.h"
29-
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
30-
#include "paddle/fluid/operators/detail/simple_block_queue.h"
31-
#include "paddle/fluid/string/printf.h"
3223

3324
namespace paddle {
3425
namespace operators {
@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase {
111102

112103
framework::Executor executor(dev_place);
113104

105+
// TODO(qiao) set proper fields for table lookup and update
106+
rpc_service_->SetExecutor(&executor);
107+
rpc_service_->SetPrefetchBlkdId(0);
108+
rpc_service_->SetProgram(program);
109+
114110
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
115111
bool exit_flag = false;
116112
// Record received sparse variables, so that
@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase {
173169
}
174170
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
175171

176-
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;
172+
VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
173+
<< "(ms)";
177174

178175
// Reset the received sparse variables, the sum operator would not
179176
// sum the input sparse variables which rows is empty at the next

0 commit comments

Comments
 (0)