Skip to content

Commit 711d86b

Browse files
committed
Polish data_type.h
1 parent ba57348 commit 711d86b

File tree

3 files changed

+88
-66
lines changed

3 files changed

+88
-66
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto)
55
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
66
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
77
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
8-
8+
cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim)
99
if(WITH_GPU)
10-
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto)
10+
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory device_context data_type)
1111
else()
12-
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto)
12+
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory device_context data_type)
1313
endif()
1414

1515
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)

paddle/fluid/framework/data_type.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/data_type.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
20+
struct DataTypeMap {
21+
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
22+
std::unordered_map<proto::VarType::Type, std::type_index> proto_to_cpp_;
23+
std::unordered_map<proto::VarType::Type, std::string> proto_to_str_;
24+
};
25+
26+
static DataTypeMap g_data_type_map_;
27+
28+
template <typename T>
29+
static inline void RegisterType(proto::VarType::Type proto_type,
30+
const std::string &name) {
31+
g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T));
32+
g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
33+
g_data_type_map_.proto_to_str_.emplace(proto_type, name);
34+
}
35+
36+
static int RegisterAllTypes() {
37+
#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
38+
39+
RegType(platform::float16, proto::VarType::FP16);
40+
RegType(float, proto::VarType::FP32);
41+
RegType(double, proto::VarType::FP64);
42+
RegType(int, proto::VarType::INT32);
43+
RegType(int64_t, proto::VarType::INT64);
44+
RegType(bool, proto::VarType::BOOL);
45+
46+
#undef RegType
47+
return 0;
48+
}
49+
50+
static std::once_flag register_once_flag_;
51+
52+
proto::VarType::Type ToDataType(std::type_index type) {
53+
std::call_once(register_once_flag_, RegisterAllTypes);
54+
auto it = g_data_type_map_.cpp_to_proto_.find(type);
55+
if (it != g_data_type_map_.cpp_to_proto_.end()) {
56+
return it->second;
57+
}
58+
PADDLE_THROW("Not support %s as tensor type", type.name());
59+
}
60+
61+
std::type_index ToTypeIndex(proto::VarType::Type type) {
62+
std::call_once(register_once_flag_, RegisterAllTypes);
63+
auto it = g_data_type_map_.proto_to_cpp_.find(type);
64+
if (it != g_data_type_map_.proto_to_cpp_.end()) {
65+
return it->second;
66+
}
67+
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
68+
static_cast<int>(type));
69+
}
70+
71+
std::string DataTypeToString(const proto::VarType::Type type) {
72+
std::call_once(register_once_flag_, RegisterAllTypes);
73+
auto it = g_data_type_map_.proto_to_str_.find(type);
74+
if (it != g_data_type_map_.proto_to_str_.end()) {
75+
return it->second;
76+
}
77+
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
78+
static_cast<int>(type));
79+
}
80+
81+
} // namespace framework
82+
} // namespace paddle

paddle/fluid/framework/data_type.h

Lines changed: 3 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,8 @@ limitations under the License. */
2222
namespace paddle {
2323
namespace framework {
2424

25-
inline proto::VarType::Type ToDataType(std::type_index type) {
26-
if (typeid(platform::float16).hash_code() == type.hash_code()) {
27-
return proto::VarType::FP16;
28-
} else if (typeid(const float).hash_code() == type.hash_code()) {
29-
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
30-
// One fix to this is to replace float with const float because
31-
// typeid(T) == typeid(const T)
32-
// http://en.cppreference.com/w/cpp/language/typeid
33-
return proto::VarType::FP32;
34-
} else if (typeid(const double).hash_code() == type.hash_code()) {
35-
return proto::VarType::FP64;
36-
} else if (typeid(const int).hash_code() == type.hash_code()) {
37-
return proto::VarType::INT32;
38-
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
39-
return proto::VarType::INT64;
40-
} else if (typeid(const bool).hash_code() == type.hash_code()) {
41-
return proto::VarType::BOOL;
42-
} else {
43-
PADDLE_THROW("Not supported");
44-
}
45-
}
46-
47-
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
48-
switch (type) {
49-
case proto::VarType::FP16:
50-
return typeid(platform::float16);
51-
case proto::VarType::FP32:
52-
return typeid(float);
53-
case proto::VarType::FP64:
54-
return typeid(double);
55-
case proto::VarType::INT32:
56-
return typeid(int);
57-
case proto::VarType::INT64:
58-
return typeid(int64_t);
59-
case proto::VarType::BOOL:
60-
return typeid(bool);
61-
default:
62-
PADDLE_THROW("Not support type %d", type);
63-
}
64-
}
65-
25+
extern proto::VarType::Type ToDataType(std::type_index type);
26+
extern std::type_index ToTypeIndex(proto::VarType::Type type);
6627
template <typename Visitor>
6728
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
6829
switch (type) {
@@ -89,32 +50,11 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
8950
}
9051
}
9152

92-
inline std::string DataTypeToString(const proto::VarType::Type type) {
93-
switch (type) {
94-
case proto::VarType::FP16:
95-
return "float16";
96-
case proto::VarType::FP32:
97-
return "float32";
98-
case proto::VarType::FP64:
99-
return "float64";
100-
case proto::VarType::INT16:
101-
return "int16";
102-
case proto::VarType::INT32:
103-
return "int32";
104-
case proto::VarType::INT64:
105-
return "int64";
106-
case proto::VarType::BOOL:
107-
return "bool";
108-
default:
109-
PADDLE_THROW("Not support type %d", type);
110-
}
111-
}
112-
53+
extern std::string DataTypeToString(const proto::VarType::Type type);
11354
inline std::ostream& operator<<(std::ostream& out,
11455
const proto::VarType::Type& type) {
11556
out << DataTypeToString(type);
11657
return out;
11758
}
118-
11959
} // namespace framework
12060
} // namespace paddle

0 commit comments

Comments
 (0)