Skip to content

Commit a2de156

Browse files
committed
refine serde code
1 parent 9a98a57 commit a2de156

File tree

4 files changed

+193
-178
lines changed

4 files changed

+193
-178
lines changed

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 126 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -29,129 +29,148 @@ namespace paddle {
2929
namespace operators {
3030
namespace detail {
3131

32+
using VarMsg = sendrecv::VariableMessage;
33+
34+
VarMsg::Type DataTypeToEnum(std::type_index type) {
35+
if (typeid(platform::float16).hash_code() == type.hash_code()) {
36+
return VarMsg::FP16;
37+
} else if (typeid(const float).hash_code() == type.hash_code()) {
38+
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
39+
// One fix to this is to replace float with const float because
40+
// typeid(T) == typeid(const T)
41+
// http://en.cppreference.com/w/cpp/language/typeid
42+
return VarMsg::FP32;
43+
} else if (typeid(const double).hash_code() == type.hash_code()) {
44+
return VarMsg::FP64;
45+
} else if (typeid(const int).hash_code() == type.hash_code()) {
46+
return VarMsg::INT32;
47+
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
48+
return VarMsg::INT64;
49+
} else if (typeid(const bool).hash_code() == type.hash_code()) {
50+
return VarMsg::BOOL;
51+
} else {
52+
PADDLE_THROW("Not supported");
53+
}
54+
}
55+
56+
// Fills `request` with the LoDTensor's meta data (data type, dims, LoD)
// and points *payload / *payload_size at the tensor's raw bytes.
// For GPU tensors the data is copied into a freshly allocated CPU buffer;
// the caller must free that buffer (via its destroy callback) once the
// bytes have been sent.
void GetTensorPayload(framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      void** payload, size_t* payload_size) {
  auto tensor = var->Get<framework::LoDTensor>();
  // FIXME(wuyi): data types in send_recv.proto is not synced with
  // framework.proto
  request->set_data_type(DataTypeToEnum(tensor.type()));
  for (auto& dim : framework::vectorize(tensor.dims())) {
    request->add_dims(dim);
  }
  const framework::LoD lod = tensor.lod();
  if (lod.size() > 0) {
    request->set_lod_level(lod.size());
    for (auto& each : lod) {
      VarMsg::LodData* lod_inner = request->add_lod();
      for (auto& d : each) {
        lod_inner->add_lod_data(d);
      }
    }
  }
  // Compute the byte size once; it is both the copy size and the
  // payload size reported to the caller.
  *payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
    platform::CPUPlace cpu;
    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
    *payload = memory::Alloc(cpu, *payload_size);

    memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
                 reinterpret_cast<const void*>(tensor.data<void>()),
                 *payload_size, gpu_dev_ctx.stream());
    ctx.Wait();
#else
    // Without CUDA support *payload would stay nullptr while
    // *payload_size is non-zero, crashing later during serialization.
    PADDLE_THROW("Tensor is on a GPU place but PaddlePaddle was compiled "
                 "without CUDA");
#endif
  } else {
    *payload = tensor.data<void>();
  }
}
94+
95+
// Fills `request` with the SelectedRows' meta data (data type, dims,
// height; lod_level is always 0) and points *payload / *payload_size at
// the value tensor's raw bytes. For GPU data the bytes are copied into a
// freshly allocated CPU buffer that the caller must free.
// Note: the rows index buffer is serialized separately by the caller.
void GetSelectedRowsPayload(framework::Variable* var,
                            const platform::DeviceContext& ctx, VarMsg* request,
                            void** payload, size_t* payload_size) {
  auto* slr = var->GetMutable<framework::SelectedRows>();
  request->set_data_type(DataTypeToEnum(slr->value().type()));
  request->set_lod_level(0);
  request->set_slr_height(slr->height());

  for (auto& dim : framework::vectorize(slr->value().dims())) {
    request->add_dims(dim);
  }

  auto* tensor = slr->mutable_value();
  // Compute the byte size once; it is both the copy size and the
  // payload size reported to the caller.
  *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    platform::CPUPlace cpu;
    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
    *payload = memory::Alloc(cpu, *payload_size);
    memory::Copy(cpu, *payload,
                 boost::get<platform::CUDAPlace>(tensor->place()),
                 reinterpret_cast<const void*>(tensor->data<void>()),
                 *payload_size, gpu_dev_ctx.stream());
    ctx.Wait();
#else
    // Without CUDA support *payload would stay nullptr while
    // *payload_size is non-zero, crashing later during serialization.
    PADDLE_THROW("SelectedRows is on a GPU place but PaddlePaddle was "
                 "compiled without CUDA");
#endif
  } else {
    // Use the `tensor` local already obtained above instead of
    // re-deriving it through slr->mutable_value().
    *payload = tensor->data<void>();
  }
}
125+
32126
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
33127
const platform::DeviceContext& ctx,
34128
::grpc::ByteBuffer* msg,
35129
const std::string& out_name) {
36-
using VarMsg = sendrecv::VariableMessage;
37-
// When using GPU, need to free the copied CPU buffer
38-
// when the ByteBuffer is destroyed
39-
// TODO(typhoonzero): add unref here, if we have dependent
40-
// parallelism execution, need to know when to free the tensor.
130+
// The default DestroyCallback does nothing; when using GPU,
131+
// the CPU buffer needs to be freed.
41132
DestroyCallback destroy_callback = [](void* backing) {};
42-
43-
auto buffer = std::unique_ptr<char[]>(new char[1024]);
44-
void* buf = buffer.get();
45-
133+
VarMsg request;
46134
void* payload = nullptr;
47135
size_t payload_size;
48-
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
136+
137+
request.set_varname(name);
49138
// Note: normally the profiler is enabled in 1 trainer, hence only
50139
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
51140
// servers the trainer's profiling state so that PS can follow the
52141
// trainer.
53-
if (platform::ShouldSendProfileState()) {
54-
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
142+
request.set_profile(platform::IsProfileEnabled());
143+
if (!out_name.empty()) {
144+
request.set_out_varname(out_name);
55145
}
56-
e.WriteString(VarMsg::kVarnameFieldNumber, name);
57146
if (var->IsType<framework::LoDTensor>()) {
58-
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
147+
request.set_type(::sendrecv::LOD_TENSOR);
148+
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
59149
} else if (var->IsType<framework::SelectedRows>()) {
60-
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
150+
request.set_type(::sendrecv::SELECTED_ROWS);
151+
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
152+
} else {
153+
PADDLE_THROW("Serialize does not support type: %s",
154+
typeid(var->Type()).name());
61155
}
62156

63-
if (!out_name.empty()) {
64-
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
157+
if (platform::is_gpu_place(ctx.GetPlace())) {
158+
// GPU data is copied to CPU buffer when sending,
159+
// free the buffer when possible.
160+
destroy_callback = [](void* backing) {
161+
platform::CPUPlace cpu;
162+
memory::Free(cpu, backing);
163+
};
65164
}
66-
switch (framework::ToVarType(var->Type())) {
67-
case framework::proto::VarType_Type_LOD_TENSOR: {
68-
auto tensor = var->Get<framework::LoDTensor>();
69-
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
70-
framework::ToDataType(tensor.type()));
71-
for (auto& dim : framework::vectorize(tensor.dims())) {
72-
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
73-
}
74-
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
75-
if (lod.size() > 0) {
76-
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
77-
78-
for (auto& each : lod) {
79-
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
80-
2 + // tag + varintlength of submessage
81-
1 + // kLodDataFieldNumber
82-
each.size());
83-
// auto copied from GPU
84-
for (auto& d : each) {
85-
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
86-
}
87-
}
88-
}
89-
if (platform::is_gpu_place(ctx.GetPlace())) {
90-
#ifdef PADDLE_WITH_CUDA
91-
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
92-
platform::CPUPlace cpu;
93-
auto& gpu_dev_ctx =
94-
static_cast<const platform::CUDADeviceContext&>(ctx);
95-
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
96-
payload = memory::Alloc(cpu, copy_size);
97-
98-
memory::Copy(cpu, payload,
99-
boost::get<platform::CUDAPlace>(tensor.place()),
100-
reinterpret_cast<const void*>(tensor.data<void>()),
101-
copy_size, gpu_dev_ctx.stream());
102-
ctx.Wait();
103-
destroy_callback = [](void* backing) {
104-
platform::CPUPlace cpu;
105-
memory::Free(cpu, backing);
106-
};
107165

108-
#endif
109-
} else {
110-
payload = tensor.data<void>();
111-
}
112-
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
113-
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
114-
} break;
115-
case framework::proto::VarType_Type_SELECTED_ROWS: {
116-
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
117-
auto* slr = var->GetMutable<framework::SelectedRows>();
118-
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
119-
framework::ToDataType(slr->value().type()));
120-
for (auto& dim : framework::vectorize(slr->value().dims())) {
121-
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
122-
}
123-
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
124-
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
125-
auto* tensor = slr->mutable_value();
126-
if (platform::is_gpu_place(ctx.GetPlace())) {
127-
#ifdef PADDLE_WITH_CUDA
128-
platform::CPUPlace cpu;
129-
auto& gpu_dev_ctx =
130-
static_cast<const platform::CUDADeviceContext&>(ctx);
131-
auto copy_size =
132-
tensor->numel() * framework::SizeOfType(tensor->type());
133-
payload = memory::Alloc(cpu, copy_size);
134-
memory::Copy(cpu, payload,
135-
boost::get<platform::CUDAPlace>(tensor->place()),
136-
reinterpret_cast<const void*>(tensor->data<void>()),
137-
copy_size, gpu_dev_ctx.stream());
138-
ctx.Wait();
139-
destroy_callback = [](void* backing) {
140-
platform::CPUPlace cpu;
141-
memory::Free(cpu, backing);
142-
};
143-
#endif
144-
} else {
145-
payload = slr->mutable_value()->data<void>();
146-
}
147-
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
148-
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
149-
} break;
150-
default:
151-
PADDLE_THROW("Serialize does not support type: %s",
152-
typeid(var->Type()).name());
153-
break;
154-
}
166+
std::string header;
167+
request.AppendToString(&header);
168+
auto buffer = std::unique_ptr<char[]>(new char[1024]);
169+
void* buf = buffer.get();
170+
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
171+
e.WriteRawBytes(std::string(header.data(), header.size()));
172+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
173+
155174
// steal reference of tensor data
156175
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
157176
int num_slices = 2; // only SelectedRows have rows buffer
@@ -162,12 +181,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
162181
static_cast<char*>(payload)),
163182
::grpc::Slice::STEAL_REF);
164183

165-
if (framework::ToVarType(var->Type()) ==
166-
framework::proto::VarType_Type_SELECTED_ROWS) {
184+
if (var->IsType<framework::SelectedRows>()) {
167185
auto* slr = var->GetMutable<framework::SelectedRows>();
168-
169186
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
170-
// NOTE: rows is of type int64_t
171187
size_t rows_memory_size =
172188
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
173189
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
@@ -178,10 +194,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
178194
grpc_slice_new_with_user_data(
179195
const_cast<void*>(
180196
reinterpret_cast<const void*>(slr->rows().data())),
181-
rows_memory_size,
182-
[](void* backing) {
183-
// TODO(typhoonzero): add unref here, same as above.
184-
},
197+
rows_memory_size, [](void* backing) {},
185198
const_cast<char*>(
186199
reinterpret_cast<const char*>(slr->rows().data()))),
187200
::grpc::Slice::STEAL_REF);

paddle/fluid/operators/detail/serde_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
117117
// serialize var to ByteBuffer
118118
framework::Variable var;
119119
auto* tensor = var.GetMutable<framework::LoDTensor>();
120-
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
120+
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
121121
framework::LoD lod;
122122
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
123123
tensor->set_lod(lod);
124-
int tensor_numel = 4 * 8 * 4 * 2;
124+
int tensor_numel = 512 * 8 * 4 * 2;
125125
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
126126
auto& ctx = *pool.Get(place);
127127
tensor->mutable_data<float>(place);
@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
142142
EXPECT_TRUE(varmsg.ParseFromString(tmp));
143143
EXPECT_EQ(varmsg.varname(), "myvar");
144144
EXPECT_EQ(varmsg.type(), 0);
145-
EXPECT_EQ(varmsg.dims()[0], 4);
145+
EXPECT_EQ(varmsg.dims()[0], 512);
146146
EXPECT_EQ(varmsg.dims()[1], 8);
147147
EXPECT_EQ(varmsg.dims()[2], 4);
148148
EXPECT_EQ(varmsg.dims()[3], 2);

0 commit comments

Comments
 (0)