Skip to content

Commit 45af8c1

Browse files
typhoonzerogongweibao
authored andcommitted
Performance/zero copy variable serialization (#8839)
1 parent 12fc76e commit 45af8c1

File tree

9 files changed

+786
-4
lines changed

9 files changed

+786
-4
lines changed

paddle/fluid/framework/tensor_util.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
187187

188188
void TensorToStream(std::ostream& os, const Tensor& tensor,
189189
const platform::DeviceContext& dev_ctx) {
190-
// TODO(typhoonzero): serialize to ostream
191190
{ // the 1st field, uint32_t version
192191
constexpr uint32_t version = 0;
193192
os.write(reinterpret_cast<const char*>(&version), sizeof(version));
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
if(WITH_DISTRIBUTE)
2-
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
2+
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
3+
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
4+
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
5+
cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
36
endif()
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
// NOTE: This file was originally created by TensorFlow
16+
// (https://github.com/tensorflow/tensorflow/); we borrowed this
17+
// file and made some modifications so that we can send gRPC
18+
// requests without too much copying of the tensor data.
19+
20+
#include "bytebuffer_stream.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
namespace detail {
25+
26+
// Builds an empty source. Every cursor field is given the same starting value
// that Init() assigns, so Next(), BackUp() and ByteCount() never read
// indeterminate members when they are called before Init(): with no slices
// loaded, Next() simply reports end-of-stream and ByteCount() returns 0.
GrpcByteBufferSource::GrpcByteBufferSource()
    : cur_(static_cast<size_t>(-1)),  // wraps to 0 on the first ++ in Next()
      left_(0),
      ptr_(nullptr),
      byte_count_(0) {}
27+
28+
// (Re)loads this stream from `src`. The read cursor and the byte counter are
// reset first, then the buffer's slices are dumped into slices_.
//
// Returns true when the dump succeeded. On failure slices_ is emptied and
// false is returned.
bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) {
  // Start from a clean cursor. cur_ is set to -1 so that the first increment
  // performed by Next() selects slice 0.
  byte_count_ = 0;
  ptr_ = nullptr;
  left_ = 0;
  cur_ = -1;

  if (src.Dump(&slices_).ok()) {
    return true;
  }
  slices_.clear();
  return false;
}
39+
40+
bool GrpcByteBufferSource::Next(const void** data, int* size) {
41+
// Use loop instead of if in case buffer contained empty slices.
42+
while (left_ == 0) {
43+
// Advance to next slice.
44+
cur_++;
45+
if (cur_ >= slices_.size()) {
46+
return false;
47+
}
48+
const ::grpc::Slice& s = slices_[cur_];
49+
left_ = s.size();
50+
ptr_ = reinterpret_cast<const char*>(s.begin());
51+
}
52+
53+
*data = ptr_;
54+
*size = left_;
55+
byte_count_ += left_;
56+
ptr_ += left_;
57+
left_ = 0;
58+
return true;
59+
}
60+
61+
// Returns the last `count` bytes to the stream: the read pointer moves back,
// the pending-byte count grows and the running byte total shrinks by `count`.
// No validation of `count` is performed here.
void GrpcByteBufferSource::BackUp(int count) {
  byte_count_ -= count;
  left_ += count;
  ptr_ -= count;
}
66+
67+
bool GrpcByteBufferSource::Skip(int count) {
68+
const void* data;
69+
int size;
70+
while (Next(&data, &size)) {
71+
if (size >= count) {
72+
BackUp(size - count);
73+
return true;
74+
}
75+
// size < count;
76+
count -= size;
77+
}
78+
// error or we have too large count;
79+
return false;
80+
}
81+
82+
// Reports the running byte total: increased by Next() for every chunk it
// yields and decreased by BackUp() for every byte handed back.
google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
  const google::protobuf::int64 total = byte_count_;
  return total;
}
85+
86+
} // namespace detail
87+
} // namespace operators
88+
} // namespace paddle
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
// NOTE: This file was originally created by TensorFlow
16+
// (https://github.com/tensorflow/tensorflow/); we borrowed this
17+
// file and made some modifications so that we can send gRPC
18+
// requests without too much copying of the tensor data.
19+
20+
#pragma once
21+
22+
#include <grpc++/grpc++.h>
23+
#include "google/protobuf/io/coded_stream.h"
24+
#include "google/protobuf/io/zero_copy_stream.h"
25+
26+
namespace paddle {
27+
namespace operators {
28+
namespace detail {
29+
30+
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
31+
class GrpcByteBufferSource
32+
: public ::google::protobuf::io::ZeroCopyInputStream {
33+
public:
34+
GrpcByteBufferSource();
35+
bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times.
36+
bool Next(const void** data, int* size) override;
37+
void BackUp(int count) override;
38+
bool Skip(int count) override;
39+
::google::protobuf::int64 ByteCount() const override;
40+
41+
private:
42+
std::vector<::grpc::Slice> slices_;
43+
size_t cur_; // Current slice index.
44+
int left_; // Number of bytes in slices_[cur_] left to yield.
45+
const char* ptr_; // Address of next byte in slices_[cur_] to yield.
46+
::google::protobuf::int64 byte_count_;
47+
};
48+
49+
} // namespace detail
50+
} // namespace operators
51+
} // namespace paddle
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
// NOTE: This file was originally created by TensorFlow
16+
// (https://github.com/tensorflow/tensorflow/); we borrowed this
17+
// file and made some modifications so that we can send gRPC
18+
// requests without too much copying of the tensor data.
19+
20+
#pragma once
21+
22+
#include <grpc++/grpc++.h>
23+
#include "paddle/fluid/platform/enforce.h"
24+
25+
namespace paddle {
26+
namespace operators {
27+
namespace detail {
28+
29+
// Writes `v` at `dst` as a base-128 varint: 7 payload bits per byte, least
// significant group first, with the high bit set on every byte except the
// last. Between 1 and 5 bytes are written; the caller must provide the space.
//
// Returns a pointer to the byte just past the last one written.
//
// Declared `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit defines the
// function several times and the link fails.
inline char* EncodeVarint32(char* dst, uint32_t v) {
  // Operate on unsigned bytes so the stores are plain truncations.
  unsigned char* out = reinterpret_cast<unsigned char*>(dst);
  constexpr uint32_t kMoreBit = 0x80;
  while (v >= kMoreBit) {
    // The cast keeps the low 8 bits: 7 payload bits plus the continuation bit.
    *out++ = static_cast<unsigned char>(v | kMoreBit);
    v >>= 7;
  }
  *out++ = static_cast<unsigned char>(v);
  return reinterpret_cast<char*>(out);
}
56+
57+
// Writes `v` at `dst` as a base-128 varint (7 payload bits per byte, least
// significant group first, high bit set on all but the final byte). Between
// 1 and 10 bytes are written; the caller must provide the space.
//
// Returns a pointer to the byte just past the last one written.
//
// Declared `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit defines the
// function several times and the link fails.
inline char* EncodeVarint64(char* dst, uint64_t v) {
  unsigned char* out = reinterpret_cast<unsigned char*>(dst);
  constexpr uint64_t kMoreBit = 0x80;
  for (; v >= kMoreBit; v >>= 7) {
    *out++ = static_cast<unsigned char>((v & (kMoreBit - 1)) | kMoreBit);
  }
  *out++ = static_cast<unsigned char>(v);
  return reinterpret_cast<char*>(out);
}
67+
68+
// Returns how many bytes the base-128 varint encoding of `v` occupies:
// one byte, plus one more for every further 7-bit group needed (1..10).
//
// Declared `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit defines the
// function several times and the link fails.
inline int VarintLength(uint64_t v) {
  int len = 1;
  for (; v >= 128; v >>= 7) {
    ++len;
  }
  return len;
}
76+
77+
class ProtoEncodeHelper {
78+
public:
79+
ProtoEncodeHelper(char* buf, int max_size)
80+
: base_(buf), p_(buf), limit_(base_ + max_size) {}
81+
82+
~ProtoEncodeHelper() {
83+
// Make sure callers didn't do operations that went over max_size promised
84+
PADDLE_ENFORCE_LE(p_, limit_);
85+
}
86+
87+
const char* data() const { return base_; }
88+
size_t size() const { return p_ - base_; }
89+
90+
void WriteUint64(int tag, uint64_t v) {
91+
Encode32(combine(tag, WIRETYPE_VARINT));
92+
Encode64(v);
93+
}
94+
void WriteBool(int tag, bool v) {
95+
Encode32(combine(tag, WIRETYPE_VARINT));
96+
EncodeBool(v);
97+
}
98+
void WriteString(int tag, const std::string& v) {
99+
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
100+
Encode32(v.size());
101+
EncodeBytes(v.data(), v.size());
102+
}
103+
void WriteVarlengthBeginning(int tag, uint32_t len) {
104+
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
105+
Encode32(len);
106+
}
107+
void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); }
108+
109+
private:
110+
// Note: this module's behavior must match the protocol buffer wire encoding
111+
// format.
112+
enum {
113+
WIRETYPE_VARINT = 0,
114+
WIRETYPE_LENGTH_DELIMITED = 2,
115+
};
116+
static uint32_t combine(uint32_t tag, uint32_t type) {
117+
return ((tag << 3) | type);
118+
}
119+
inline void Encode32(uint32_t v) {
120+
if (v < 128) {
121+
// Fast path for single-byte values. Many of the calls will use a
122+
// constant value for v, so the comparison will get optimized away
123+
// when Encode32 is inlined into the caller.
124+
*p_ = v;
125+
p_++;
126+
} else {
127+
p_ = EncodeVarint32(p_, v);
128+
}
129+
}
130+
void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); }
131+
void EncodeBool(bool v) {
132+
*p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1
133+
p_++;
134+
}
135+
void EncodeBytes(const char* bytes, int N) {
136+
memcpy(p_, bytes, N);
137+
p_ += N;
138+
}
139+
140+
char* base_;
141+
char* p_;
142+
char* limit_; // Just for CHECKs
143+
};
144+
145+
} // detail
146+
} // operators
147+
} // paddle

paddle/fluid/operators/detail/send_recv.proto

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,34 @@ enum VarType {
3333
}
3434

3535
message VariableMessage {
36+
enum Type {
37+
// Pod Types
38+
BOOL = 0;
39+
INT16 = 1;
40+
INT32 = 2;
41+
INT64 = 3;
42+
FP16 = 4;
43+
FP32 = 5;
44+
FP64 = 6;
45+
}
46+
47+
message LodData { repeated int64 lod_data = 1; }
48+
3649
string varname = 1;
3750
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
3851
VarType type = 2;
39-
bytes serialized = 3;
52+
// bool persistable is not needed for sending.
53+
// tensor info:
54+
Type data_type = 3;
55+
repeated int64 dims = 4;
56+
57+
// lod details:
58+
int64 lod_level = 5;
59+
repeated LodData lod = 6;
60+
// tensor data
61+
bytes serialized = 7;
62+
// selected_rows data
63+
bytes rows = 8;
4064
}
4165

4266
message VoidMessage {}

0 commit comments

Comments
 (0)