@@ -31,35 +31,14 @@ namespace detail {
31
31
32
32
using VarMsg = sendrecv::VariableMessage;
33
33
34
- VarMsg::Type DataTypeToEnum (std::type_index type) {
35
- if (typeid (platform::float16).hash_code () == type.hash_code ()) {
36
- return VarMsg::FP16;
37
- } else if (typeid (const float ).hash_code () == type.hash_code ()) {
38
- // CPPLint complains Using C-style cast. Use static_cast<float>() instead
39
- // One fix to this is to replace float with const float because
40
- // typeid(T) == typeid(const T)
41
- // http://en.cppreference.com/w/cpp/language/typeid
42
- return VarMsg::FP32;
43
- } else if (typeid (const double ).hash_code () == type.hash_code ()) {
44
- return VarMsg::FP64;
45
- } else if (typeid (const int ).hash_code () == type.hash_code ()) {
46
- return VarMsg::INT32;
47
- } else if (typeid (const int64_t ).hash_code () == type.hash_code ()) {
48
- return VarMsg::INT64;
49
- } else if (typeid (const bool ).hash_code () == type.hash_code ()) {
50
- return VarMsg::BOOL;
51
- } else {
52
- PADDLE_THROW (" Not supported" );
53
- }
54
- }
55
-
56
34
void GetTensorPayload (framework::Variable* var,
57
35
const platform::DeviceContext& ctx, VarMsg* request,
58
36
void ** payload, size_t * payload_size) {
59
37
auto tensor = var->Get <framework::LoDTensor>();
60
- // FIXME(wuyi): data types in send_recv.proto is not synced with
38
+ // FIXME(wuyi): data types in send_recv.proto is copied from
61
39
// framework.proto
62
- request->set_data_type (DataTypeToEnum (tensor.type ()));
40
+ request->set_data_type (
41
+ static_cast <VarMsg::Type>(framework::ToDataType (tensor.type ())));
63
42
for (auto & dim : framework::vectorize (tensor.dims ())) {
64
43
request->add_dims (dim);
65
44
}
@@ -96,7 +75,8 @@ void GetSelectedRowsPayload(framework::Variable* var,
96
75
const platform::DeviceContext& ctx, VarMsg* request,
97
76
void ** payload, size_t * payload_size) {
98
77
auto * slr = var->GetMutable <framework::SelectedRows>();
99
- request->set_data_type (DataTypeToEnum (slr->value ().type ()));
78
+ request->set_data_type (
79
+ static_cast <VarMsg::Type>(framework::ToDataType (slr->value ().type ())));
100
80
request->set_lod_level (0 );
101
81
request->set_slr_height (slr->height ());
102
82
@@ -170,7 +150,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
170
150
ProtoEncodeHelper e (static_cast <char *>(buf), 1024 );
171
151
e.WriteRawBytes (std::string (header.data (), header.size ()));
172
152
e.WriteVarlengthBeginning (VarMsg::kSerializedFieldNumber , payload_size);
173
-
174
153
// steal reference of tensor data
175
154
::grpc::Slice slices[4 ]; // metadata, tensor, rows meta, rows
176
155
int num_slices = 2 ; // only SelectedRows have rows buffer
0 commit comments