
Commit 61343fb

Merge pull request #10531 from typhoonzero/refine_grpc_serde_code
Refine serde code
2 parents 6d371e4 + 796a448 commit 61343fb
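
In short: SerializeToByteBuffer no longer hand-encodes every VariableMessage field through ProtoEncodeHelper. Payload extraction for LoDTensor and SelectedRows moves into two new helpers (GetTensorPayload, GetSelectedRowsPayload), the metadata is built as a regular sendrecv::VariableMessage and serialized by protobuf itself, and VariableResponse::Parse now consumes packed varint runs by byte count rather than element count.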

4 files changed: +172 −178 lines changed

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 105 additions & 113 deletions
@@ -29,129 +29,127 @@ namespace paddle {
 namespace operators {
 namespace detail {
 
+using VarMsg = sendrecv::VariableMessage;
+
+void GetTensorPayload(framework::Variable* var,
+                      const platform::DeviceContext& ctx, VarMsg* request,
+                      void** payload, size_t* payload_size) {
+  auto tensor = var->Get<framework::LoDTensor>();
+  // FIXME(wuyi): data types in send_recv.proto is copied from
+  // framework.proto
+  request->set_data_type(
+      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  for (auto& dim : framework::vectorize(tensor.dims())) {
+    request->add_dims(dim);
+  }
+  const framework::LoD lod = tensor.lod();
+  if (lod.size() > 0) {
+    request->set_lod_level(lod.size());
+    for (auto& each : lod) {
+      VarMsg::LodData* lod_inner = request->add_lod();
+      for (auto& d : each) {
+        lod_inner->add_lod_data(d);
+      }
+    }
+  }
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
+    *payload = memory::Alloc(cpu, copy_size);
+
+    memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
+                 reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = tensor.data<void>();
+  }
+  *payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
+}
+
+void GetSelectedRowsPayload(framework::Variable* var,
+                            const platform::DeviceContext& ctx, VarMsg* request,
+                            void** payload, size_t* payload_size) {
+  auto* slr = var->GetMutable<framework::SelectedRows>();
+  request->set_data_type(
+      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_lod_level(0);
+  request->set_slr_height(slr->height());
+
+  for (auto& dim : framework::vectorize(slr->value().dims())) {
+    request->add_dims(dim);
+  }
+
+  auto* tensor = slr->mutable_value();
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
+    *payload = memory::Alloc(cpu, copy_size);
+    memory::Copy(cpu, *payload,
+                 boost::get<platform::CUDAPlace>(tensor->place()),
+                 reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = slr->mutable_value()->data<void>();
+  }
+  *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
+}
+
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg,
                            const std::string& out_name) {
-  using VarMsg = sendrecv::VariableMessage;
-  // When using GPU, need to free the copied CPU buffer
-  // when the ByteBuffer destroies
-  // TODO(typhoonzero): add unref here, if we have dependent
-  // parallelism execution, need to know when to free the tensor.
+  // Default DestroyCallback does nothing, When using GPU
+  // the CPU buffer need to be freed.
   DestroyCallback destroy_callback = [](void* backing) {};
-
-  auto buffer = std::unique_ptr<char[]>(new char[1024]);
-  void* buf = buffer.get();
-
+  VarMsg request;
   void* payload = nullptr;
   size_t payload_size;
-  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+
+  request.set_varname(name);
   // Note: normally the profiler is enabled in 1 trainer, hence only
   // 1 trainer returns true for ShouldSendProfileState(). It tells PS
   // servers the trainer's profiling state so that PS can follow the
   // trainer.
-  if (platform::ShouldSendProfileState()) {
-    e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
+  request.set_profile(platform::IsProfileEnabled());
+  if (!out_name.empty()) {
+    request.set_out_varname(out_name);
   }
-  e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
+    request.set_type(::sendrecv::LOD_TENSOR);
+    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
   } else if (var->IsType<framework::SelectedRows>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
+    request.set_type(::sendrecv::SELECTED_ROWS);
+    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+  } else {
+    PADDLE_THROW("Serialize does not support type: %s",
+                 typeid(var->Type()).name());
   }
 
-  if (!out_name.empty()) {
-    e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+    // GPU data is copied to CPU buffer when sending,
+    // free the buffer when possible.
+    destroy_callback = [](void* backing) {
+      platform::CPUPlace cpu;
+      memory::Free(cpu, backing);
+    };
   }
-  switch (framework::ToVarType(var->Type())) {
-    case framework::proto::VarType_Type_LOD_TENSOR: {
-      auto tensor = var->Get<framework::LoDTensor>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(tensor.type()));
-      for (auto& dim : framework::vectorize(tensor.dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      auto lod = tensor.lod();  // std::vector<Vector<size_t>>
-      if (lod.size() > 0) {
-        e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
-
-        for (auto& each : lod) {
-          e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
-                                    2 +  // tag + varintlength of submessage
-                                        1 +  // kLodDataFieldNumber
-                                        each.size());
-          // auto copied from GPU
-          for (auto& d : each) {
-            e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
-          }
-        }
-      }
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
-        payload = memory::Alloc(cpu, copy_size);
-
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor.place()),
-                     reinterpret_cast<const void*>(tensor.data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
 
-#endif
-      } else {
-        payload = tensor.data<void>();
-      }
-      payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    case framework::proto::VarType_Type_SELECTED_ROWS: {
-      // TODO(typhoonzero): selectedrows implement should not use unique_ptr
-      auto* slr = var->GetMutable<framework::SelectedRows>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(slr->value().type()));
-      for (auto& dim : framework::vectorize(slr->value().dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
-      e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
-      auto* tensor = slr->mutable_value();
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size =
-            tensor->numel() * framework::SizeOfType(tensor->type());
-        payload = memory::Alloc(cpu, copy_size);
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor->place()),
-                     reinterpret_cast<const void*>(tensor->data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
-#endif
-      } else {
-        payload = slr->mutable_value()->data<void>();
-      }
-      payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    default:
-      PADDLE_THROW("Serialize does not support type: %s",
-                   typeid(var->Type()).name());
-      break;
-  }
+  std::string header;
+  request.AppendToString(&header);
+  auto buffer = std::unique_ptr<char[]>(new char[1024]);
+  void* buf = buffer.get();
+  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+  e.WriteRawBytes(std::string(header.data(), header.size()));
+  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
   // steal reference of tensor data
   ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
   int num_slices = 2;       // only SelectedRows have rows buffer
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                             static_cast<char*>(payload)),
                         ::grpc::Slice::STEAL_REF);
 
-  if (framework::ToVarType(var->Type()) ==
-      framework::proto::VarType_Type_SELECTED_ROWS) {
+  if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
-
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    // NOTE: rows is of type int64_t
     size_t rows_memory_size =
         slr->rows().size() * framework::SizeOfType(typeid(int64_t));
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
         grpc_slice_new_with_user_data(
             const_cast<void*>(
                 reinterpret_cast<const void*>(slr->rows().data())),
-            rows_memory_size,
-            [](void* backing) {
-              // TODO(typhoonzero): add unref here, same as above.
-            },
+            rows_memory_size, [](void* backing) {},
             const_cast<char*>(
                 reinterpret_cast<const char*>(slr->rows().data()))),
         ::grpc::Slice::STEAL_REF);
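
The heart of the refactor above: all metadata fields are now set on a generated sendrecv::VariableMessage and serialized with request.AppendToString, and ProtoEncodeHelper survives only to write the tag and byte length of the large length-delimited `serialized` field, after which the tensor memory is attached as a zero-copy gRPC slice. The sketch below reproduces that framing idea with plain protobuf primitives. It is a minimal illustration, not Paddle code: BuildWireHeader and its parameters are invented for this example, and WireFormatLite lives in a protobuf-internal namespace.

#include <cstdint>
#include <string>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/wire_format_lite.h>

// Serialize metadata with the generated code, then hand-append only the
// tag + byte-length header of a big length-delimited field. The payload
// bytes are never copied here; the sender transmits them as a separate
// buffer immediately after this header (the role of the gRPC slice above).
std::string BuildWireHeader(const std::string& metadata_bytes,
                            int payload_field_number, size_t payload_size) {
  using google::protobuf::internal::WireFormatLite;
  std::string header = metadata_bytes;  // already-valid wire-format bytes
  {
    google::protobuf::io::StringOutputStream raw(&header);
    google::protobuf::io::CodedOutputStream out(&raw);
    WireFormatLite::WriteTag(payload_field_number,
                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED, &out);
    out.WriteVarint32(static_cast<uint32_t>(payload_size));
  }  // CodedOutputStream trims `header` to the bytes actually written
  return header;
}

Concatenating this header with the raw payload bytes yields the same wire format as if the payload had been copied into the message field directly, which is why the receiver can still parse the whole stream as one VariableMessage.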

paddle/fluid/operators/detail/serde_test.cc

Lines changed: 3 additions & 3 deletions
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
   // serialize var to ByteBuffer
   framework::Variable var;
   auto* tensor = var.GetMutable<framework::LoDTensor>();
-  tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
+  tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
   framework::LoD lod;
   lod.push_back(framework::Vector<size_t>({1, 3, 8}));
   tensor->set_lod(lod);
-  int tensor_numel = 4 * 8 * 4 * 2;
+  int tensor_numel = 512 * 8 * 4 * 2;
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& ctx = *pool.Get(place);
   tensor->mutable_data<float>(place);
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
   EXPECT_TRUE(varmsg.ParseFromString(tmp));
   EXPECT_EQ(varmsg.varname(), "myvar");
   EXPECT_EQ(varmsg.type(), 0);
-  EXPECT_EQ(varmsg.dims()[0], 4);
+  EXPECT_EQ(varmsg.dims()[0], 512);
   EXPECT_EQ(varmsg.dims()[1], 8);
   EXPECT_EQ(varmsg.dims()[2], 4);
   EXPECT_EQ(varmsg.dims()[3], 2);
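
A plausible reading of why the test tensor grew, given the 1024-byte scratch buffer in the new serializer: a 4 × 8 × 4 × 2 float tensor occupies 4 · 8 · 4 · 2 · 4 = 1024 bytes, exactly the ProtoEncodeHelper buffer size, while 512 × 8 × 4 × 2 floats occupy 131072 bytes, so the enlarged payload unambiguously exercises the separate-slice path rather than anything that could fit beside the metadata.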

paddle/fluid/operators/detail/variable_response.cc

Lines changed: 20 additions & 20 deletions
@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
   }
 
   if (wt == WIRETYPE_LENGTH_DELIMITED) {
-    int length = 0;
-    if (!input->ReadVarintSizeAsInt(&length)) {
+    int num_bytes = 0;
+    if (!input->ReadVarintSizeAsInt(&num_bytes)) {
       return tag;
     }
-
-    for (int i = 0; i < length; i++) {
+    int start_pos = input->CurrentPosition();
+    while (input->CurrentPosition() - start_pos < num_bytes) {
       uint64_t v;
       if (!input->ReadVarint64(&v)) {
-        return false;
+        return tag;
       }
       lod->push_back(v);
     }
@@ -275,17 +275,17 @@ int VariableResponse::Parse(Source* source) {
       break;
     }
     case sendrecv::VariableMessage::kTypeFieldNumber: {
-      uint64_t v;
-      if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
+      uint32_t v;
+      if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
        return tag;
      }
 
       meta_.set_type(static_cast<::sendrecv::VarType>(v));
       break;
     }
     case sendrecv::VariableMessage::kDataTypeFieldNumber: {
-      uint64_t v = 0;
-      if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
+      uint32_t v = 0;
+      if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
        return tag;
      }
 
@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
 
       // packed
       if (wt == WIRETYPE_LENGTH_DELIMITED) {
-        int length = 0;
-        if (!input.ReadVarintSizeAsInt(&length)) {
+        int num_bytes = 0;
+        if (!input.ReadVarintSizeAsInt(&num_bytes)) {
           return tag;
         }
-        for (int i = 0; i < length; i++) {
+        int start_pos = input.CurrentPosition();
+        while (input.CurrentPosition() - start_pos < num_bytes) {
           uint64_t v;
           if (!input.ReadVarint64(&v)) {
             return tag;
@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
         }
         break;
       }
-
       return tag;
     }
     case sendrecv::VariableMessage::kLodLevelFieldNumber: {
@@ -372,24 +372,24 @@ int VariableResponse::Parse(Source* source) {
                      meta_.varname() != "",
                      "meta info should be got first!");
 
-      int length = 0;
+      int num_bytes = 0;
       if (wt != WIRETYPE_LENGTH_DELIMITED ||
-          !ReadVarintSizeAsInt(&input, &length)) {
+          !ReadVarintSizeAsInt(&input, &num_bytes)) {
         return tag;
       }
 
       framework::DDim dims = GetDims(meta_.dims());
       if (meta_.type() == sendrecv::LOD_TENSOR) {
         PADDLE_ENFORCE(meta_.lod_size() >= 0,
                        "lod info should be got first!");
-        if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) {
+        if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) {
           return tag;
         }
         break;
       }
 
       if (meta_.type() == sendrecv::SELECTED_ROWS) {
-        if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) {
+        if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) {
          return tag;
        }
        break;
@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
                      meta_.varname() != "",
                      "meta info should be got first!");
 
-      int length = 0;
+      int num_bytes = 0;
       if (wt != WIRETYPE_LENGTH_DELIMITED ||
-          !ReadVarintSizeAsInt(&input, &length)) {
+          !ReadVarintSizeAsInt(&input, &num_bytes)) {
         return tag;
       }
 
-      if (!CopySelectRowsData(&input, *dev_ctx_, length)) {
+      if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
        return tag;
      }
      break;
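
Two parsing fixes recur above. First, `type` and `data_type` are protobuf enum fields, which are 32-bit varints on the wire, so ReadVarint32 replaces ReadVarint64 and the subsequent static_cast no longer narrows a 64-bit value. Second, and more subtle: the length prefix of a length-delimited packed field counts bytes, not elements, so the old `for (int i = 0; i < length; i++)` loop read one varint per byte and overran the field whenever any value needed more than one byte. The standalone sketch below demonstrates the corrected byte-position loop. It uses only protobuf's CodedInputStream; the ReadPackedVarints helper and the hand-built wire bytes are assumptions of the example, not Paddle code.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

#include <google/protobuf/io/coded_stream.h>

// Decode a packed varint run once its length prefix is reached: consume
// varints until exactly `num_bytes` bytes are used, mirroring the
// CurrentPosition() loops added in ParseLodData and the dims case above.
std::vector<uint64_t> ReadPackedVarints(const std::string& wire_bytes) {
  google::protobuf::io::CodedInputStream input(
      reinterpret_cast<const uint8_t*>(wire_bytes.data()),
      static_cast<int>(wire_bytes.size()));
  std::vector<uint64_t> values;
  int num_bytes = 0;
  if (!input.ReadVarintSizeAsInt(&num_bytes)) return values;
  int start_pos = input.CurrentPosition();
  while (input.CurrentPosition() - start_pos < num_bytes) {
    uint64_t v = 0;
    if (!input.ReadVarint64(&v)) break;
    values.push_back(v);
  }
  return values;
}

int main() {
  // Length prefix 3, then varints 1 (0x01) and 300 (0xAC 0x02): two values
  // in three bytes, which a count-based loop would have misread as three.
  const std::string wire_bytes{'\x03', '\x01', '\xAC', '\x02'};
  for (uint64_t v : ReadPackedVarints(wire_bytes)) {
    std::cout << v << "\n";  // prints 1, then 300
  }
  return 0;
}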
