@@ -29,129 +29,148 @@ namespace paddle {
 namespace operators {
 namespace detail {
 
+using VarMsg = sendrecv::VariableMessage;
+
+VarMsg::Type DataTypeToEnum(std::type_index type) {
+  if (typeid(platform::float16).hash_code() == type.hash_code()) {
+    return VarMsg::FP16;
+  } else if (typeid(const float).hash_code() == type.hash_code()) {
+    // CPPLint complains "Using C-style cast. Use static_cast<float>() instead"
+    // for typeid(float). One fix is to replace float with const float,
+    // because typeid(T) == typeid(const T):
+    // http://en.cppreference.com/w/cpp/language/typeid
+    return VarMsg::FP32;
+  } else if (typeid(const double).hash_code() == type.hash_code()) {
+    return VarMsg::FP64;
+  } else if (typeid(const int).hash_code() == type.hash_code()) {
+    return VarMsg::INT32;
+  } else if (typeid(const int64_t).hash_code() == type.hash_code()) {
+    return VarMsg::INT64;
+  } else if (typeid(const bool).hash_code() == type.hash_code()) {
+    return VarMsg::BOOL;
+  } else {
+    PADDLE_THROW("Not supported");
+  }
+}
+
+void GetTensorPayload(framework::Variable* var,
+                      const platform::DeviceContext& ctx, VarMsg* request,
+                      void** payload, size_t* payload_size) {
+  auto tensor = var->Get<framework::LoDTensor>();
+  // FIXME(wuyi): the data types in send_recv.proto are not synced with
+  // framework.proto
+  request->set_data_type(DataTypeToEnum(tensor.type()));
+  for (auto& dim : framework::vectorize(tensor.dims())) {
+    request->add_dims(dim);
+  }
+  const framework::LoD lod = tensor.lod();
+  if (lod.size() > 0) {
+    request->set_lod_level(lod.size());
+    for (auto& each : lod) {
+      VarMsg::LodData* lod_inner = request->add_lod();
+      for (auto& d : each) {
+        lod_inner->add_lod_data(d);
+      }
+    }
+  }
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
+    *payload = memory::Alloc(cpu, copy_size);
+
+    memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
+                 reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = tensor.data<void>();
+  }
+  *payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
+}
+
+void GetSelectedRowsPayload(framework::Variable* var,
+                            const platform::DeviceContext& ctx, VarMsg* request,
+                            void** payload, size_t* payload_size) {
+  auto* slr = var->GetMutable<framework::SelectedRows>();
+  request->set_data_type(DataTypeToEnum(slr->value().type()));
+  request->set_lod_level(0);
+  request->set_slr_height(slr->height());
+
+  for (auto& dim : framework::vectorize(slr->value().dims())) {
+    request->add_dims(dim);
+  }
+
+  auto* tensor = slr->mutable_value();
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
+    *payload = memory::Alloc(cpu, copy_size);
+    memory::Copy(cpu, *payload,
+                 boost::get<platform::CUDAPlace>(tensor->place()),
+                 reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = slr->mutable_value()->data<void>();
+  }
+  *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
+}
+
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg,
                            const std::string& out_name) {
-  using VarMsg = sendrecv::VariableMessage;
-  // When using GPU, need to free the copied CPU buffer
-  // when the ByteBuffer destroies
-  // TODO(typhoonzero): add unref here, if we have dependent
-  // parallelism execution, need to know when to free the tensor.
+  // The default DestroyCallback does nothing. When using GPU,
+  // the copied CPU buffer needs to be freed.
   DestroyCallback destroy_callback = [](void* backing) {};
-
-  auto buffer = std::unique_ptr<char[]>(new char[1024]);
-  void* buf = buffer.get();
-
+  VarMsg request;
   void* payload = nullptr;
   size_t payload_size;
-  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+
+  request.set_varname(name);
   // Note: normally the profiler is enabled in 1 trainer, hence only
   // 1 trainer returns true for ShouldSendProfileState(). It tells PS
   // servers the trainer's profiling state so that PS can follow the
   // trainer.
-  if (platform::ShouldSendProfileState()) {
-    e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
+  request.set_profile(platform::IsProfileEnabled());
+  if (!out_name.empty()) {
+    request.set_out_varname(out_name);
   }
-  e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
+    request.set_type(::sendrecv::LOD_TENSOR);
+    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
   } else if (var->IsType<framework::SelectedRows>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
+    request.set_type(::sendrecv::SELECTED_ROWS);
+    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+  } else {
+    PADDLE_THROW("Serialize does not support type: %s",
+                 typeid(var->Type()).name());
   }
 
-  if (!out_name.empty()) {
-    e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+    // GPU data is copied to a CPU buffer when sending,
+    // so free the buffer when possible.
+    destroy_callback = [](void* backing) {
+      platform::CPUPlace cpu;
+      memory::Free(cpu, backing);
+    };
   }
-  switch (framework::ToVarType(var->Type())) {
-    case framework::proto::VarType_Type_LOD_TENSOR: {
-      auto tensor = var->Get<framework::LoDTensor>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(tensor.type()));
-      for (auto& dim : framework::vectorize(tensor.dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      auto lod = tensor.lod();  // std::vector<Vector<size_t>>
-      if (lod.size() > 0) {
-        e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
-
-        for (auto& each : lod) {
-          e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
-                                    2 +  // tag + varintlength of submessage
-                                    1 +  // kLodDataFieldNumber
-                                    each.size());
-          // auto copied from GPU
-          for (auto& d : each) {
-            e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
-          }
-        }
-      }
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
-        payload = memory::Alloc(cpu, copy_size);
-
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor.place()),
-                     reinterpret_cast<const void*>(tensor.data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
 
-#endif
-      } else {
-        payload = tensor.data<void>();
-      }
-      payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    case framework::proto::VarType_Type_SELECTED_ROWS: {
-      // TODO(typhoonzero): selectedrows implement should not use unique_ptr
-      auto* slr = var->GetMutable<framework::SelectedRows>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(slr->value().type()));
-      for (auto& dim : framework::vectorize(slr->value().dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
-      e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
-      auto* tensor = slr->mutable_value();
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size =
-            tensor->numel() * framework::SizeOfType(tensor->type());
-        payload = memory::Alloc(cpu, copy_size);
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor->place()),
-                     reinterpret_cast<const void*>(tensor->data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
-#endif
-      } else {
-        payload = slr->mutable_value()->data<void>();
-      }
-      payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    default:
-      PADDLE_THROW("Serialize does not support type: %s",
-                   typeid(var->Type()).name());
-      break;
-  }
+  std::string header;
+  request.AppendToString(&header);
+  auto buffer = std::unique_ptr<char[]>(new char[1024]);
+  void* buf = buffer.get();
+  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+  e.WriteRawBytes(std::string(header.data(), header.size()));
+  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
+
   // steal reference of tensor data
   ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
   int num_slices = 2;       // only SelectedRows have rows buffer
@@ -162,12 +181,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                                 static_cast<char*>(payload)),
       ::grpc::Slice::STEAL_REF);
 
-  if (framework::ToVarType(var->Type()) ==
-      framework::proto::VarType_Type_SELECTED_ROWS) {
+  if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
-
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    // NOTE: rows is of type int64_t
     size_t rows_memory_size =
         slr->rows().size() * framework::SizeOfType(typeid(int64_t));
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
@@ -178,10 +194,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
         grpc_slice_new_with_user_data(
             const_cast<void*>(
                 reinterpret_cast<const void*>(slr->rows().data())),
-            rows_memory_size,
-            [](void* backing) {
-              // TODO(typhoonzero): add unref here, same as above.
-            },
+            rows_memory_size, [](void* backing) {},
             const_cast<char*>(
                 reinterpret_cast<const char*>(slr->rows().data()))),
         ::grpc::Slice::STEAL_REF);
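
Note on the new serialization path (a reviewer sketch, not part of the patch): `request.AppendToString(&header)` produces the encoded `VariableMessage` header, and `e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size)` appends only the tag and varint length of the length-delimited `serialized` field; the raw tensor bytes then follow as a separate zero-copy gRPC slice. The sketch below shows the protobuf wire framing this relies on. The helper names (`EncodeVarint`, `LengthDelimitedPrefix`) are illustrative and not part of the codebase.

```cpp
// Illustrative only: how a length-delimited protobuf field is framed on the
// wire. Concatenating header bytes + this prefix + the raw payload slice
// still parses as one well-formed message.
#include <cstdint>
#include <string>

// Encode an unsigned integer as a protobuf base-128 varint.
std::string EncodeVarint(uint64_t value) {
  std::string out;
  while (value >= 0x80) {
    out.push_back(static_cast<char>((value & 0x7F) | 0x80));
    value >>= 7;
  }
  out.push_back(static_cast<char>(value));
  return out;
}

// Prefix for a length-delimited field (wire type 2): the field tag followed
// by the varint-encoded payload length.
std::string LengthDelimitedPrefix(uint32_t field_number, uint64_t payload_size) {
  std::string prefix = EncodeVarint((static_cast<uint64_t>(field_number) << 3) | 2);
  prefix += EncodeVarint(payload_size);
  return prefix;
}
```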
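For reference, a minimal caller of the refactored `SerializeToByteBuffer` might look like the sketch below. It is a sketch under assumptions, not code from this patch: the include paths, the variable name `"param_w"`, and the tensor shape are illustrative, and it assumes the usual fluid `LoDTensor`/`CPUDeviceContext` API on a CPU-only build.

```cpp
#include <grpc++/grpc++.h>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"  // assumed path
#include "paddle/fluid/platform/device_context.h"

void SerializeExample() {
  namespace fw = paddle::framework;
  namespace plat = paddle::platform;

  // Build a small float LoDTensor on the CPU.
  fw::Variable var;
  auto* tensor = var.GetMutable<fw::LoDTensor>();
  tensor->Resize(fw::make_ddim({2, 3}));
  plat::CPUPlace place;
  float* data = tensor->mutable_data<float>(place);
  for (int i = 0; i < 6; ++i) data[i] = static_cast<float>(i);

  // The VariableMessage header (varname, type, dims, lod, ...) is written
  // first; the tensor bytes are attached afterwards as a zero-copy slice.
  plat::CPUDeviceContext ctx(place);
  ::grpc::ByteBuffer msg;
  paddle::operators::detail::SerializeToByteBuffer("param_w", &var, ctx, &msg,
                                                   /*out_name=*/"");
}
```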