Skip to content

Commit 602aa43

Browse files
committed
cast data type
1 parent a2de156 commit 602aa43

File tree

1 file changed

+5
-26
lines changed

1 file changed

+5
-26
lines changed

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,35 +31,14 @@ namespace detail {
3131

3232
using VarMsg = sendrecv::VariableMessage;
3333

34-
VarMsg::Type DataTypeToEnum(std::type_index type) {
35-
if (typeid(platform::float16).hash_code() == type.hash_code()) {
36-
return VarMsg::FP16;
37-
} else if (typeid(const float).hash_code() == type.hash_code()) {
38-
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
39-
// One fix to this is to replace float with const float because
40-
// typeid(T) == typeid(const T)
41-
// http://en.cppreference.com/w/cpp/language/typeid
42-
return VarMsg::FP32;
43-
} else if (typeid(const double).hash_code() == type.hash_code()) {
44-
return VarMsg::FP64;
45-
} else if (typeid(const int).hash_code() == type.hash_code()) {
46-
return VarMsg::INT32;
47-
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
48-
return VarMsg::INT64;
49-
} else if (typeid(const bool).hash_code() == type.hash_code()) {
50-
return VarMsg::BOOL;
51-
} else {
52-
PADDLE_THROW("Not supported");
53-
}
54-
}
55-
5634
void GetTensorPayload(framework::Variable* var,
5735
const platform::DeviceContext& ctx, VarMsg* request,
5836
void** payload, size_t* payload_size) {
5937
auto tensor = var->Get<framework::LoDTensor>();
60-
// FIXME(wuyi): data types in send_recv.proto is not synced with
38+
// FIXME(wuyi): data types in send_recv.proto is copied from
6139
// framework.proto
62-
request->set_data_type(DataTypeToEnum(tensor.type()));
40+
request->set_data_type(
41+
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
6342
for (auto& dim : framework::vectorize(tensor.dims())) {
6443
request->add_dims(dim);
6544
}
@@ -96,7 +75,8 @@ void GetSelectedRowsPayload(framework::Variable* var,
9675
const platform::DeviceContext& ctx, VarMsg* request,
9776
void** payload, size_t* payload_size) {
9877
auto* slr = var->GetMutable<framework::SelectedRows>();
99-
request->set_data_type(DataTypeToEnum(slr->value().type()));
78+
request->set_data_type(
79+
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
10080
request->set_lod_level(0);
10181
request->set_slr_height(slr->height());
10282

@@ -170,7 +150,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
170150
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
171151
e.WriteRawBytes(std::string(header.data(), header.size()));
172152
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
173-
174153
// steal reference of tensor data
175154
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
176155
int num_slices = 2; // only SelectedRows have rows buffer

0 commit comments

Comments
 (0)