Skip to content

Commit 3a6213f

Browse files
authored
Change grpc interface to compatible with brpc. (#12164)
1 parent b063093 commit 3a6213f

19 files changed

+748
-550
lines changed
Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,43 @@
1+
if(NOT WITH_DISTRIBUTE)
2+
return()
3+
endif()
4+
5+
if(WITH_GRPC)
6+
set(cc_generic_services "false")
7+
else()
8+
set(cc_generic_services "true")
9+
endif()
10+
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
11+
112
if(WITH_GRPC)
2-
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
3-
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
4-
selected_rows memory)
13+
grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
14+
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
15+
PROTO send_recv.proto
16+
DEPS lod_tensor selected_rows memory)
517
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
618
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
7-
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
8-
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
9-
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc
10-
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
11-
proto_desc lookup_table_op SERIAL)
19+
cc_test(grpc_serde_test SRCS grpc_serde_test.cc
20+
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
21+
cc_test(grpc_server_test SRCS rpc_server_test.cc
22+
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
1223
return()
1324
endif()
1425

1526

1627
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
17-
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
18-
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc
28+
29+
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
30+
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
31+
32+
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
33+
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
1934
PROTO send_recv.proto
2035
DEPS lod_tensor selected_rows memory)
2136

22-
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so)
23-
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
24-
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
25-
37+
set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
2638

27-
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so)
28-
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
29-
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
39+
cc_test(brpc_server_test SRCS rpc_server_test.cc
40+
DEPS ${brpc_test_depends} SERIAL)
3041

31-
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc
32-
brpc protobuf leveldb gflags glog
33-
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
42+
cc_test(brpc_serde_test SRCS brpc_serde_test.cc
43+
DEPS ${brpc_test_depends} SERIAL)

paddle/fluid/operators/distributed/bytebuffer_stream.cc renamed to paddle/fluid/operators/distributed/grpc_bytebuffer_stream.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License. */
1717
// file and did some modifications so that we can send gRPC
1818
// requests without too much copying of the tensor data.
1919

20-
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h"
20+
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
2121

2222
namespace paddle {
2323
namespace operators {

paddle/fluid/operators/distributed/bytebuffer_stream.h renamed to paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License. */
2424
#include "google/protobuf/io/coded_stream.h"
2525
#include "google/protobuf/io/zero_copy_stream.h"
2626
#include "grpc++/grpc++.h"
27+
#include "paddle/fluid/operators/distributed/variable_response.h"
2728

2829
namespace grpc {
2930
// A ZeroCopyInputStream that reads from grpc_byte_buffer
@@ -107,25 +108,6 @@ class GrpcBufferReader final
107108
namespace paddle {
108109
namespace operators {
109110
namespace distributed {
110-
// Source provides a way for a particular RPC implementation to provide
111-
// received data to ParseFrom.
112-
class Source {
113-
public:
114-
virtual ~Source() {}
115-
116-
// Return the stream that contains the data to be parsed.
117-
// Note that this method might be invoked more than once if
118-
// ParseFrom needs to fall back to a more expensive parsing method.
119-
// Every call must return a stream pointing at the beginning of
120-
// the serialized RecvTensorResponse.
121-
//
122-
// Note that a subsequent call to contents() invalidates previous
123-
// results of contents().
124-
//
125-
// Ownership of the returned stream is retained by the Source and
126-
// should not be deleted by the caller.
127-
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
128-
};
129111

130112
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
131113
class GrpcByteBufferSource

paddle/fluid/operators/distributed/grpc_client.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020

2121
#include "glog/logging.h" // For VLOG
2222
#include "paddle/fluid/framework/threadpool.h"
23+
#include "paddle/fluid/operators/distributed/grpc_serde.h"
2324
#include "paddle/fluid/operators/distributed/request_handler.h"
2425
#include "paddle/fluid/platform/profiler.h"
2526

paddle/fluid/operators/distributed/grpc_client.h

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,17 @@ limitations under the License. */
3838
#include "paddle/fluid/framework/lod_tensor.h"
3939
#include "paddle/fluid/framework/scope.h"
4040
#include "paddle/fluid/framework/selected_rows.h"
41+
#include "paddle/fluid/operators/distributed/request_handler.h"
4142
#include "paddle/fluid/operators/distributed/rpc_client.h"
43+
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
44+
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
4245
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
4346
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
4447

4548
namespace paddle {
4649
namespace operators {
4750
namespace distributed {
4851

49-
struct VarHandle {
50-
// RPC endpoint.
51-
std::string ep;
52-
const platform::DeviceContext* ctx;
53-
const framework::Scope* scope;
54-
// Variable name.
55-
std::string name;
56-
// RPC method name.
57-
std::string method;
58-
59-
std::string String() const {
60-
std::ostringstream s;
61-
s << method << " name:[" << name << "], ep:[" << ep << "]";
62-
return s.str();
63-
}
64-
};
65-
6652
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
6753

6854
class BaseProcessor {
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifdef PADDLE_WITH_CUDA
16+
#include <nccl.h>
17+
#endif
18+
#include <sys/time.h>
19+
#include <thread> // NOLINT
20+
21+
#include "google/protobuf/io/coded_stream.h"
22+
#include "google/protobuf/io/zero_copy_stream.h"
23+
#include "paddle/fluid/framework/data_type.h"
24+
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
25+
#include "paddle/fluid/operators/distributed/grpc_serde.h"
26+
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
27+
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
28+
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
29+
#include "paddle/fluid/platform/profiler.h"
30+
31+
namespace paddle {
32+
namespace operators {
33+
namespace distributed {
34+
35+
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
36+
const platform::DeviceContext& ctx,
37+
::grpc::ByteBuffer* msg,
38+
const std::string& out_name) {
39+
// Default DestroyCallback does nothing, When using GPU
40+
// the CPU buffer need to be freed.
41+
DestroyCallback destroy_callback = [](void* backing) {};
42+
VarMsg request;
43+
void* payload = nullptr;
44+
size_t payload_size;
45+
46+
request.set_varname(name);
47+
// Note: normally the profiler is enabled in 1 trainer, hence only
48+
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
49+
// servers the trainer's profiling state so that PS can follow the
50+
// trainer.
51+
if (platform::ShouldSendProfileState()) {
52+
if (platform::IsProfileEnabled()) {
53+
request.set_profile(platform::kEnableProfiler);
54+
} else {
55+
request.set_profile(platform::kDisableProfiler);
56+
}
57+
}
58+
if (!out_name.empty()) {
59+
request.set_out_varname(out_name);
60+
}
61+
if (var->IsType<framework::LoDTensor>()) {
62+
request.set_type(::sendrecv::LOD_TENSOR);
63+
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
64+
} else if (var->IsType<framework::SelectedRows>()) {
65+
request.set_type(::sendrecv::SELECTED_ROWS);
66+
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
67+
#ifdef PADDLE_WITH_CUDA
68+
} else if (var->IsType<ncclUniqueId>()) {
69+
request.set_type(::sendrecv::NCCL_ID);
70+
#endif
71+
} else {
72+
PADDLE_THROW("Serialize does not support type: %s",
73+
typeid(var->Type()).name());
74+
}
75+
76+
if (platform::is_gpu_place(ctx.GetPlace())) {
77+
#ifdef PADDLE_WITH_CUDA
78+
// GPU data is copied to CPU buffer when sending,
79+
// free the buffer when possible.
80+
destroy_callback = [](void* backing) {
81+
platform::CUDAPinnedPlace cuda_pinned;
82+
memory::Free(cuda_pinned, backing);
83+
};
84+
#endif
85+
}
86+
87+
std::string header;
88+
request.AppendToString(&header);
89+
auto buffer = std::unique_ptr<char[]>(new char[1024]);
90+
void* buf = buffer.get();
91+
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
92+
e.WriteRawBytes(std::string(header.data(), header.size()));
93+
// NCCLID is copied directly to the message, return bytebuffer
94+
// with only one slice if serializing NCCLID.
95+
#ifdef PADDLE_WITH_CUDA
96+
if (var->IsType<ncclUniqueId>()) {
97+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
98+
NCCL_UNIQUE_ID_BYTES);
99+
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
100+
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
101+
102+
// for serialize NCCL_ID
103+
::grpc::Slice slices(e.size());
104+
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
105+
::grpc::ByteBuffer tmp(&slices, 1);
106+
msg->Swap(&tmp);
107+
return;
108+
}
109+
#endif
110+
111+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
112+
// steal reference of tensor data
113+
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
114+
int num_slices = 2; // only SelectedRows have rows buffer
115+
slices[0] = ::grpc::Slice(e.size());
116+
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
117+
slices[1] = ::grpc::Slice(
118+
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
119+
static_cast<char*>(payload)),
120+
::grpc::Slice::STEAL_REF);
121+
122+
if (var->IsType<framework::SelectedRows>()) {
123+
auto* slr = var->GetMutable<framework::SelectedRows>();
124+
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
125+
size_t rows_memory_size =
126+
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
127+
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
128+
slices[2] = ::grpc::Slice(e2.size());
129+
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
130+
131+
slices[3] = ::grpc::Slice(
132+
grpc_slice_new_with_user_data(
133+
const_cast<void*>(
134+
reinterpret_cast<const void*>(slr->rows().data())),
135+
rows_memory_size, [](void* backing) {},
136+
const_cast<char*>(
137+
reinterpret_cast<const char*>(slr->rows().data()))),
138+
::grpc::Slice::STEAL_REF);
139+
num_slices = 4;
140+
}
141+
142+
::grpc::ByteBuffer tmp(&slices[0], num_slices);
143+
msg->Swap(&tmp);
144+
}
145+
146+
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
147+
const platform::DeviceContext& ctx,
148+
const framework::Scope* scope,
149+
framework::Variable** var) {
150+
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
151+
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
152+
*var = resp.GetVar();
153+
}
154+
155+
} // namespace distributed
156+
} // namespace operators
157+
} // namespace paddle
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <sys/time.h>
17+
#include <iostream>
18+
#include <string>
19+
#include <vector>
20+
21+
#include "paddle/fluid/framework/data_type.h"
22+
#include "paddle/fluid/framework/lod_tensor.h"
23+
#include "paddle/fluid/framework/scope.h"
24+
#include "paddle/fluid/framework/selected_rows.h"
25+
#include "paddle/fluid/framework/tensor_util.h"
26+
#include "paddle/fluid/framework/var_type.h"
27+
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
28+
29+
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
30+
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
31+
32+
namespace paddle {
33+
namespace operators {
34+
namespace distributed {
35+
36+
typedef void (*DestroyCallback)(void*);
37+
38+
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
39+
const platform::DeviceContext& ctx,
40+
::grpc::ByteBuffer* msg,
41+
const std::string& out_varname = std::string());
42+
43+
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
44+
const platform::DeviceContext& ctx,
45+
const framework::Scope* scope,
46+
framework::Variable** var);
47+
48+
} // namespace distributed
49+
} // namespace operators
50+
} // namespace paddle

paddle/fluid/operators/distributed/grpc_serde_test.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/lod_tensor.h"
2222
#include "paddle/fluid/framework/tensor_util.h"
2323
#include "paddle/fluid/framework/variable.h"
24+
#include "paddle/fluid/operators/detail/macros.h"
25+
#include "paddle/fluid/operators/distributed/grpc_serde.h"
26+
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
2427
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
25-
#include "paddle/fluid/operators/distributed/variable_response.h"
2628
#include "paddle/fluid/operators/math/math_function.h"
2729
#include "paddle/fluid/platform/place.h"
2830
#include "paddle/fluid/string/printf.h"
@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
8486
// operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
8587
framework::Scope scope;
8688
scope.Var("myvar");
87-
operators::distributed::VariableResponse resp(&scope, &ctx);
89+
operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
8890
EXPECT_EQ(resp.Parse(msg), 0);
8991

9092
framework::Variable* var2 = resp.GetVar();
@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
171173
// deserialize zero-copy
172174
framework::Scope scope;
173175
scope.Var("myvar");
174-
operators::distributed::VariableResponse resp(&scope, &ctx);
176+
operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
175177
if (from_type == 0) {
176178
EXPECT_EQ(resp.Parse(msg), 0);
177179
} else {

0 commit comments

Comments
 (0)