@@ -29,129 +29,127 @@ namespace paddle {
 namespace operators {
 namespace detail {
+using VarMsg = sendrecv::VariableMessage;
+
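+// GetTensorPayload fills `request` with a LoDTensor's metadata (dtype,
+// dims, LoD) and points *payload at the raw tensor bytes.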
+void GetTensorPayload(framework::Variable* var,
+                      const platform::DeviceContext& ctx, VarMsg* request,
+                      void** payload, size_t* payload_size) {
+  auto tensor = var->Get<framework::LoDTensor>();
+  // FIXME(wuyi): data types in send_recv.proto are copied from
+  // framework.proto
+  request->set_data_type(
+      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  for (auto& dim : framework::vectorize(tensor.dims())) {
+    request->add_dims(dim);
+  }
+  const framework::LoD lod = tensor.lod();
+  if (lod.size() > 0) {
+    request->set_lod_level(lod.size());
+    for (auto& each : lod) {
+      VarMsg::LodData* lod_inner = request->add_lod();
+      for (auto& d : each) {
+        lod_inner->add_lod_data(d);
+      }
+    }
+  }
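+  // On GPU, the tensor bytes must be staged into a CPU buffer before gRPC
+  // can read them; that buffer is freed later by the destroy callback
+  // installed in SerializeToByteBuffer.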
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
+    *payload = memory::Alloc(cpu, copy_size);
+
+    memory::Copy(cpu, *payload,
+                 boost::get<platform::CUDAPlace>(tensor.place()),
+                 reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
+                 gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = tensor.data<void>();
+  }
+  *payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
+}
+
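+// GetSelectedRowsPayload does the same for a SelectedRows variable:
+// metadata (dtype, dims, height) goes into `request`, and *payload points
+// at the value tensor's bytes.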
+void GetSelectedRowsPayload(framework::Variable* var,
+                            const platform::DeviceContext& ctx,
+                            VarMsg* request, void** payload,
+                            size_t* payload_size) {
+  auto* slr = var->GetMutable<framework::SelectedRows>();
+  request->set_data_type(
+      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_lod_level(0);
+  request->set_slr_height(slr->height());
+
+  for (auto& dim : framework::vectorize(slr->value().dims())) {
+    request->add_dims(dim);
+  }
+
+  auto* tensor = slr->mutable_value();
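+  // Same GPU-to-CPU staging as in GetTensorPayload above.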
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef PADDLE_WITH_CUDA
+    platform::CPUPlace cpu;
+    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
+    auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
+    *payload = memory::Alloc(cpu, copy_size);
+    memory::Copy(cpu, *payload,
+                 boost::get<platform::CUDAPlace>(tensor->place()),
+                 reinterpret_cast<const void*>(tensor->data<void>()),
+                 copy_size, gpu_dev_ctx.stream());
+    ctx.Wait();
+#endif
+  } else {
+    *payload = slr->mutable_value()->data<void>();
+  }
+  *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
+}
+
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg,
                            const std::string& out_name) {
-  using VarMsg = sendrecv::VariableMessage;
-  // When using GPU, need to free the copied CPU buffer
-  // when the ByteBuffer destroies
-  // TODO(typhoonzero): add unref here, if we have dependent
-  // parallelism execution, need to know when to free the tensor.
+  // The default DestroyCallback does nothing. When using GPU, the
+  // copied CPU buffer needs to be freed once the ByteBuffer is destroyed.
   DestroyCallback destroy_callback = [](void* backing) {};
-
-  auto buffer = std::unique_ptr<char[]>(new char[1024]);
-  void* buf = buffer.get();
-
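+  // Collect all variable metadata in a protobuf message; the raw tensor
+  // bytes are attached afterwards as a separate zero-copy slice.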
+  VarMsg request;
   void* payload = nullptr;
   size_t payload_size;
-  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+
+  request.set_varname(name);
   // Note: normally the profiler is enabled in 1 trainer, hence only
   // 1 trainer returns true for ShouldSendProfileState(). It tells PS
   // servers the trainer's profiling state so that PS can follow the
   // trainer.
-  if (platform::ShouldSendProfileState()) {
-    e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
+  request.set_profile(platform::IsProfileEnabled());
+  if (!out_name.empty()) {
+    request.set_out_varname(out_name);
   }
-  e.WriteString(VarMsg::kVarnameFieldNumber, name);
   if (var->IsType<framework::LoDTensor>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
+    request.set_type(::sendrecv::LOD_TENSOR);
+    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
   } else if (var->IsType<framework::SelectedRows>()) {
-    e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
+    request.set_type(::sendrecv::SELECTED_ROWS);
+    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+  } else {
+    PADDLE_THROW("Serialize does not support type: %s",
+                 typeid(var->Type()).name());
   }
 
-  if (!out_name.empty()) {
-    e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+    // GPU data was copied into a CPU buffer for sending; free that
+    // buffer once the ByteBuffer is destroyed.
+    destroy_callback = [](void* backing) {
+      platform::CPUPlace cpu;
+      memory::Free(cpu, backing);
+    };
   }
-  switch (framework::ToVarType(var->Type())) {
-    case framework::proto::VarType_Type_LOD_TENSOR: {
-      auto tensor = var->Get<framework::LoDTensor>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(tensor.type()));
-      for (auto& dim : framework::vectorize(tensor.dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      auto lod = tensor.lod();  // std::vector<Vector<size_t>>
-      if (lod.size() > 0) {
-        e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
-
-        for (auto& each : lod) {
-          e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
-                                    2 +  // tag + varintlength of submessage
-                                        1 +  // kLodDataFieldNumber
-                                        each.size());
-          // auto copied from GPU
-          for (auto& d : each) {
-            e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
-          }
-        }
-      }
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
-        payload = memory::Alloc(cpu, copy_size);
-
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor.place()),
-                     reinterpret_cast<const void*>(tensor.data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
 
-#endif
-      } else {
-        payload = tensor.data<void>();
-      }
-      payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    case framework::proto::VarType_Type_SELECTED_ROWS: {
-      // TODO(typhoonzero): selectedrows implement should not use unique_ptr
-      auto* slr = var->GetMutable<framework::SelectedRows>();
-      e.WriteUint64(VarMsg::kDataTypeFieldNumber,
-                    framework::ToDataType(slr->value().type()));
-      for (auto& dim : framework::vectorize(slr->value().dims())) {
-        e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
-      }
-      e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
-      e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
-      auto* tensor = slr->mutable_value();
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef PADDLE_WITH_CUDA
-        platform::CPUPlace cpu;
-        auto& gpu_dev_ctx =
-            static_cast<const platform::CUDADeviceContext&>(ctx);
-        auto copy_size =
-            tensor->numel() * framework::SizeOfType(tensor->type());
-        payload = memory::Alloc(cpu, copy_size);
-        memory::Copy(cpu, payload,
-                     boost::get<platform::CUDAPlace>(tensor->place()),
-                     reinterpret_cast<const void*>(tensor->data<void>()),
-                     copy_size, gpu_dev_ctx.stream());
-        ctx.Wait();
-        destroy_callback = [](void* backing) {
-          platform::CPUPlace cpu;
-          memory::Free(cpu, backing);
-        };
-#endif
-      } else {
-        payload = slr->mutable_value()->data<void>();
-      }
-      payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
-      e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
-    } break;
-    default:
-      PADDLE_THROW("Serialize does not support type: %s",
-                   typeid(var->Type()).name());
-      break;
-  }
+  std::string header;
+  request.AppendToString(&header);
+  auto buffer = std::unique_ptr<char[]>(new char[1024]);
+  void* buf = buffer.get();
+  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
+  e.WriteRawBytes(std::string(header.data(), header.size()));
+  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
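+  // `e` writes the pre-serialized metadata bytes into the 1KB scratch
+  // buffer, then hand-encodes the tag/length of the `serialized` field so
+  // the raw payload bytes can follow in the next slice without copying.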
   // steal reference of tensor data
   ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
   int num_slices = 2;       // only SelectedRows have rows buffer
@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                                     static_cast<char*>(payload)),
       ::grpc::Slice::STEAL_REF);
 
-  if (framework::ToVarType(var->Type()) ==
-      framework::proto::VarType_Type_SELECTED_ROWS) {
+  if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
-
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    // NOTE: rows is of type int64_t
     size_t rows_memory_size =
         slr->rows().size() * framework::SizeOfType(typeid(int64_t));
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
         grpc_slice_new_with_user_data(
             const_cast<void*>(
                 reinterpret_cast<const void*>(slr->rows().data())),
-            rows_memory_size,
-            [](void* backing) {
-              // TODO(typhoonzero): add unref here, same as above.
-            },
+            rows_memory_size, [](void* backing) {},
             const_cast<char*>(
                 reinterpret_cast<const char*>(slr->rows().data()))),
         ::grpc::Slice::STEAL_REF);