@@ -53,109 +53,106 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
53
53
e.WriteUint64 (VarMsg::kTypeFieldNumber , 1 );
54
54
} else if (var->IsType <ncclUniqueId>()) {
55
55
// NOTE: sendrecv only supports the RAW type for NCCL_ID
56
+ VLOG (3 ) << " serilizing: setting var type nccl id" ;
56
57
e.WriteUint64 (VarMsg::kTypeFieldNumber , 2 );
57
58
}
58
59
59
60
if (!out_name.empty ()) {
60
61
e.WriteString (VarMsg::kOutVarnameFieldNumber , out_name);
61
62
}
62
- switch (framework::ToVarType (var->Type ())) {
63
- case framework::proto::VarType_Type_LOD_TENSOR: {
64
- auto tensor = var->Get <framework::LoDTensor>();
65
- e.WriteUint64 (VarMsg::kDataTypeFieldNumber ,
66
- framework::ToDataType (tensor.type ()));
67
- for (auto & dim : framework::vectorize (tensor.dims ())) {
68
- e.WriteUint64 (VarMsg::kDimsFieldNumber , dim);
69
- }
70
- auto lod = tensor.lod (); // std::vector<Vector<size_t>>
71
- if (lod.size () > 0 ) {
72
- e.WriteUint64 (VarMsg::kLodLevelFieldNumber , lod.size ());
73
-
74
- for (auto & each : lod) {
75
- e.WriteVarlengthBeginning (VarMsg::kLodFieldNumber ,
76
- 2 + // tag + varintlength of submessage
77
- 1 + // kLodDataFieldNumber
78
- each.size ());
79
- // auto copied from GPU
80
- for (auto & d : each) {
81
- e.WriteUint64 (VarMsg::LodData::kLodDataFieldNumber , d);
82
- }
63
+ if (var->IsType <framework::LoDTensor>()) {
64
+ // ===========================Tensor==================================
65
+ auto tensor = var->Get <framework::LoDTensor>();
66
+ e.WriteUint64 (VarMsg::kDataTypeFieldNumber ,
67
+ framework::ToDataType (tensor.type ()));
68
+ for (auto & dim : framework::vectorize (tensor.dims ())) {
69
+ e.WriteUint64 (VarMsg::kDimsFieldNumber , dim);
70
+ }
71
+ auto lod = tensor.lod (); // std::vector<Vector<size_t>>
72
+ if (lod.size () > 0 ) {
73
+ e.WriteUint64 (VarMsg::kLodLevelFieldNumber , lod.size ());
74
+
75
+ for (auto & each : lod) {
76
+ e.WriteVarlengthBeginning (VarMsg::kLodFieldNumber ,
77
+ 2 + // tag + varintlength of submessage
78
+ 1 + // kLodDataFieldNumber
79
+ each.size ());
80
+ // auto copied from GPU
81
+ for (auto & d : each) {
82
+ e.WriteUint64 (VarMsg::LodData::kLodDataFieldNumber , d);
83
83
}
84
84
}
85
- if (platform::is_gpu_place (ctx.GetPlace ())) {
85
+ }
86
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
86
87
#ifdef PADDLE_WITH_CUDA
87
- PADDLE_ENFORCE (platform::is_gpu_place (tensor.place ()));
88
+ PADDLE_ENFORCE (platform::is_gpu_place (tensor.place ()));
89
+ platform::CPUPlace cpu;
90
+ auto & gpu_dev_ctx = static_cast <const platform::CUDADeviceContext&>(ctx);
91
+ auto copy_size = tensor.numel () * framework::SizeOfType (tensor.type ());
92
+ payload = memory::Alloc (cpu, copy_size);
93
+
94
+ memory::Copy (cpu, payload,
95
+ boost::get<platform::CUDAPlace>(tensor.place ()),
96
+ reinterpret_cast <const void *>(tensor.data <void >()),
97
+ copy_size, gpu_dev_ctx.stream ());
98
+ ctx.Wait ();
99
+ destroy_callback = [](void * backing) {
88
100
platform::CPUPlace cpu;
89
- auto & gpu_dev_ctx =
90
- static_cast <const platform::CUDADeviceContext&>(ctx);
91
- auto copy_size = tensor.numel () * framework::SizeOfType (tensor.type ());
92
- payload = memory::Alloc (cpu, copy_size);
93
-
94
- memory::Copy (cpu, payload,
95
- boost::get<platform::CUDAPlace>(tensor.place ()),
96
- reinterpret_cast <const void *>(tensor.data <void >()),
97
- copy_size, gpu_dev_ctx.stream ());
98
- ctx.Wait ();
99
- destroy_callback = [](void * backing) {
100
- platform::CPUPlace cpu;
101
- memory::Free (cpu, backing);
102
- };
101
+ memory::Free (cpu, backing);
102
+ };
103
103
104
104
#endif
105
- } else {
106
- payload = tensor.data <void >();
107
- }
108
- payload_size = tensor.numel () * framework::SizeOfType (tensor.type ());
109
- e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber , payload_size);
110
- } break ;
111
- case framework::proto::VarType_Type_SELECTED_ROWS: {
112
- // TODO(typhoonzero): selectedrows implement should not use unique_ptr
113
- auto * slr = var->GetMutable <framework::SelectedRows>();
114
- e.WriteUint64 (VarMsg::kDataTypeFieldNumber ,
115
- framework::ToDataType (slr->value ().type ()));
116
- for (auto & dim : framework::vectorize (slr->value ().dims ())) {
117
- e.WriteUint64 (VarMsg::kDimsFieldNumber , dim);
118
- }
119
- e.WriteUint64 (VarMsg::kLodLevelFieldNumber , 0 );
120
- e.WriteUint64 (VarMsg::kSlrHeightFieldNumber , slr->height ());
121
- auto * tensor = slr->mutable_value ();
122
- if (platform::is_gpu_place (ctx.GetPlace ())) {
105
+ } else {
106
+ payload = tensor.data <void >();
107
+ }
108
+ payload_size = tensor.numel () * framework::SizeOfType (tensor.type ());
109
+ e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber , payload_size);
110
+ } else if (var->IsType <framework::SelectedRows>()) {
111
+ // ===========================SELECTED
112
+ // ROWS==================================
113
+ // TODO(typhoonzero): selectedrows implement should not use unique_ptr
114
+ auto * slr = var->GetMutable <framework::SelectedRows>();
115
+ e.WriteUint64 (VarMsg::kDataTypeFieldNumber ,
116
+ framework::ToDataType (slr->value ().type ()));
117
+ for (auto & dim : framework::vectorize (slr->value ().dims ())) {
118
+ e.WriteUint64 (VarMsg::kDimsFieldNumber , dim);
119
+ }
120
+ e.WriteUint64 (VarMsg::kLodLevelFieldNumber , 0 );
121
+ e.WriteUint64 (VarMsg::kSlrHeightFieldNumber , slr->height ());
122
+ auto * tensor = slr->mutable_value ();
123
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
123
124
#ifdef PADDLE_WITH_CUDA
125
+ platform::CPUPlace cpu;
126
+ auto & gpu_dev_ctx = static_cast <const platform::CUDADeviceContext&>(ctx);
127
+ auto copy_size = tensor->numel () * framework::SizeOfType (tensor->type ());
128
+ payload = memory::Alloc (cpu, copy_size);
129
+ memory::Copy (cpu, payload,
130
+ boost::get<platform::CUDAPlace>(tensor->place ()),
131
+ reinterpret_cast <const void *>(tensor->data <void >()),
132
+ copy_size, gpu_dev_ctx.stream ());
133
+ ctx.Wait ();
134
+ destroy_callback = [](void * backing) {
124
135
platform::CPUPlace cpu;
125
- auto & gpu_dev_ctx =
126
- static_cast <const platform::CUDADeviceContext&>(ctx);
127
- auto copy_size =
128
- tensor->numel () * framework::SizeOfType (tensor->type ());
129
- payload = memory::Alloc (cpu, copy_size);
130
- memory::Copy (cpu, payload,
131
- boost::get<platform::CUDAPlace>(tensor->place ()),
132
- reinterpret_cast <const void *>(tensor->data <void >()),
133
- copy_size, gpu_dev_ctx.stream ());
134
- ctx.Wait ();
135
- destroy_callback = [](void * backing) {
136
- platform::CPUPlace cpu;
137
- memory::Free (cpu, backing);
138
- };
136
+ memory::Free (cpu, backing);
137
+ };
139
138
#endif
140
- } else {
141
- payload = slr->mutable_value ()->data <void >();
142
- }
143
- payload_size = tensor->numel () * framework::SizeOfType (tensor->type ());
144
- e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber , payload_size);
145
- } break ;
146
- case framework::proto::VarType_Type_RAW: {
147
- e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber ,
148
- NCCL_UNIQUE_ID_BYTES);
149
- ncclUniqueId* uid = var->GetMutable <ncclUniqueId>();
150
- e.WriteRawBytes (std::string (uid->internal , NCCL_UNIQUE_ID_BYTES));
151
- } break ;
152
- default :
153
- PADDLE_THROW (" Serialize does not support type: %s" ,
154
- typeid (var->Type ()).name ());
155
- break ;
139
+ } else {
140
+ payload = slr->mutable_value ()->data <void >();
141
+ }
142
+ payload_size = tensor->numel () * framework::SizeOfType (tensor->type ());
143
+ e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber , payload_size);
144
+ } else if (var->IsType <ncclUniqueId>()) {
145
+ // ===========================NCCL ID==================================
146
+ e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber ,
147
+ NCCL_UNIQUE_ID_BYTES);
148
+ ncclUniqueId* uid = var->GetMutable <ncclUniqueId>();
149
+ e.WriteRawBytes (std::string (uid->internal , NCCL_UNIQUE_ID_BYTES));
150
+ } else {
151
+ PADDLE_THROW (" Serialize does not support type: %s" ,
152
+ typeid (var->Type ()).name ());
156
153
}
157
154
158
- if (framework::ToVarType ( var->Type ()) == framework::proto::VarType_Type_RAW ) {
155
+ if (var->IsType <ncclUniqueId>() ) {
159
156
// for serializing NCCL_ID
160
157
::grpc::Slice slices (e.size ());
161
158
memcpy (const_cast <uint8_t *>(slices.begin ()), e.data (), e.size ());
0 commit comments