Skip to content

Commit 5f6fd26

Browse files
authored
Merge pull request #10572 from reyoung/feature/polish_visit_data_type
Polish data_type.h
2 parents 30c350b + c70ddb0 commit 5f6fd26

File tree

6 files changed

+114
-108
lines changed

6 files changed

+114
-108
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto)
55
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
66
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
77
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
8-
8+
cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
99
if(WITH_GPU)
10-
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto)
10+
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type)
1111
else()
12-
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto)
12+
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type)
1313
endif()
1414

1515
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)

paddle/fluid/framework/data_type.cc

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/data_type.h"
16+
#include <stdint.h>
17+
#include <string>
18+
#include <unordered_map>
19+
20+
namespace paddle {
21+
namespace framework {
22+
23+
struct DataTypeMap {
24+
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
25+
std::unordered_map<int, std::type_index> proto_to_cpp_;
26+
std::unordered_map<int, std::string> proto_to_str_;
27+
std::unordered_map<std::type_index, size_t> cpp_to_size_;
28+
};
29+
30+
// Defined below, after the RegisterType helper it depends on.
static DataTypeMap* InitDataTypeMap();

// Returns the process-wide type registry. The pointed-to map is
// heap-allocated exactly once and intentionally never freed, which
// sidesteps static-destruction-order problems at program exit.
static DataTypeMap& gDataTypeMap() {
  static DataTypeMap* s_registry = InitDataTypeMap();
  return *s_registry;
}
35+
36+
template <typename T>
37+
static inline void RegisterType(DataTypeMap* map,
38+
proto::VarType::Type proto_type,
39+
const std::string& name) {
40+
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
41+
map->cpp_to_proto_.emplace(typeid(T), proto_type);
42+
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
43+
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
44+
}
45+
46+
// Builds the singleton registry consulted by gDataTypeMap(); invoked
// exactly once. The returned object is deliberately leaked (see
// gDataTypeMap()).
static DataTypeMap* InitDataTypeMap() {
  auto* retv = new DataTypeMap();

  // NOTE: add new custom types here. Each registered name is the exact
  // C++ spelling of the type, matching what `#cc_type` stringification
  // would produce.
  RegisterType<platform::float16>(retv, proto::VarType::FP16,
                                  "platform::float16");
  RegisterType<float>(retv, proto::VarType::FP32, "float");
  RegisterType<double>(retv, proto::VarType::FP64, "double");
  RegisterType<int>(retv, proto::VarType::INT32, "int");
  RegisterType<int64_t>(retv, proto::VarType::INT64, "int64_t");
  RegisterType<bool>(retv, proto::VarType::BOOL, "bool");
  RegisterType<size_t>(retv, proto::VarType::SIZE_T, "size_t");
  RegisterType<int16_t>(retv, proto::VarType::INT16, "int16_t");

  return retv;
}
65+
66+
// Maps a C++ std::type_index to its proto::VarType::Type enum value.
// Raises via PADDLE_THROW when the type was never registered.
proto::VarType::Type ToDataType(std::type_index type) {
  const auto& lookup = gDataTypeMap().cpp_to_proto_;
  auto iter = lookup.find(type);
  if (iter == lookup.end()) {
    PADDLE_THROW("Not support %s as tensor type", type.name());
  }
  return iter->second;
}
73+
74+
// Inverse of ToDataType: maps a proto::VarType::Type enum value back to
// the std::type_index of the registered C++ type. Raises via
// PADDLE_THROW for unregistered enum values.
std::type_index ToTypeIndex(proto::VarType::Type type) {
  const auto& lookup = gDataTypeMap().proto_to_cpp_;
  auto iter = lookup.find(static_cast<int>(type));
  if (iter == lookup.end()) {
    PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
                 static_cast<int>(type));
  }
  return iter->second;
}
82+
83+
// Returns the registered display name (the C++ spelling, e.g. "float")
// of a proto::VarType::Type. Raises via PADDLE_THROW for unregistered
// enum values.
std::string DataTypeToString(const proto::VarType::Type type) {
  const auto& lookup = gDataTypeMap().proto_to_str_;
  auto iter = lookup.find(static_cast<int>(type));
  if (iter == lookup.end()) {
    PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
                 static_cast<int>(type));
  }
  return iter->second;
}
91+
92+
size_t SizeOfType(std::type_index type) {
93+
auto it = gDataTypeMap().cpp_to_size_.find(type);
94+
if (it != gDataTypeMap().cpp_to_size_.end()) {
95+
return it->second;
96+
}
97+
PADDLE_THROW("Not support %s as tensor type", type.name());
98+
}
99+
100+
} // namespace framework
101+
} // namespace paddle

paddle/fluid/framework/data_type.h

Lines changed: 5 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -17,51 +17,14 @@ limitations under the License. */
1717
#include <typeindex>
1818
#include "paddle/fluid/framework/framework.pb.h"
1919
#include "paddle/fluid/platform/enforce.h"
20+
2021
#include "paddle/fluid/platform/float16.h"
2122

2223
namespace paddle {
2324
namespace framework {
2425

25-
inline proto::VarType::Type ToDataType(std::type_index type) {
26-
if (typeid(platform::float16).hash_code() == type.hash_code()) {
27-
return proto::VarType::FP16;
28-
} else if (typeid(const float).hash_code() == type.hash_code()) {
29-
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
30-
// One fix to this is to replace float with const float because
31-
// typeid(T) == typeid(const T)
32-
// http://en.cppreference.com/w/cpp/language/typeid
33-
return proto::VarType::FP32;
34-
} else if (typeid(const double).hash_code() == type.hash_code()) {
35-
return proto::VarType::FP64;
36-
} else if (typeid(const int).hash_code() == type.hash_code()) {
37-
return proto::VarType::INT32;
38-
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
39-
return proto::VarType::INT64;
40-
} else if (typeid(const bool).hash_code() == type.hash_code()) {
41-
return proto::VarType::BOOL;
42-
} else {
43-
PADDLE_THROW("Not supported");
44-
}
45-
}
46-
47-
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
48-
switch (type) {
49-
case proto::VarType::FP16:
50-
return typeid(platform::float16);
51-
case proto::VarType::FP32:
52-
return typeid(float);
53-
case proto::VarType::FP64:
54-
return typeid(double);
55-
case proto::VarType::INT32:
56-
return typeid(int);
57-
case proto::VarType::INT64:
58-
return typeid(int64_t);
59-
case proto::VarType::BOOL:
60-
return typeid(bool);
61-
default:
62-
PADDLE_THROW("Not support type %d", type);
63-
}
64-
}
26+
extern proto::VarType::Type ToDataType(std::type_index type);
27+
extern std::type_index ToTypeIndex(proto::VarType::Type type);
6528

6629
template <typename Visitor>
6730
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
@@ -89,32 +52,12 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
8952
}
9053
}
9154

92-
inline std::string DataTypeToString(const proto::VarType::Type type) {
93-
switch (type) {
94-
case proto::VarType::FP16:
95-
return "float16";
96-
case proto::VarType::FP32:
97-
return "float32";
98-
case proto::VarType::FP64:
99-
return "float64";
100-
case proto::VarType::INT16:
101-
return "int16";
102-
case proto::VarType::INT32:
103-
return "int32";
104-
case proto::VarType::INT64:
105-
return "int64";
106-
case proto::VarType::BOOL:
107-
return "bool";
108-
default:
109-
PADDLE_THROW("Not support type %d", type);
110-
}
111-
}
112-
55+
extern std::string DataTypeToString(const proto::VarType::Type type);
56+
extern size_t SizeOfType(std::type_index type);
11357
inline std::ostream& operator<<(std::ostream& out,
11458
const proto::VarType::Type& type) {
11559
out << DataTypeToString(type);
11660
return out;
11761
}
118-
11962
} // namespace framework
12063
} // namespace paddle

paddle/fluid/framework/framework.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ message VarType {
101101
FP16 = 4;
102102
FP32 = 5;
103103
FP64 = 6;
104+
// Tensor<size_t> is used in C++.
105+
SIZE_T = 19;
104106

105107
// Other types that may need additional descriptions
106108
LOD_TENSOR = 7;

paddle/fluid/framework/op_kernel_type_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ TEST(OpKernelType, ToString) {
2727
LibraryType::kCUDNN);
2828

2929
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
30-
"data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type["
30+
"data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type["
3131
"CUDNN]");
3232
}
3333

paddle/fluid/framework/tensor_impl.h

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include "paddle/fluid/framework/data_type.h"
1617
#include "paddle/fluid/memory/memcpy.h"
1718
#include "paddle/fluid/platform/enforce.h"
1819
#include "paddle/fluid/platform/float16.h"
1920

2021
namespace paddle {
2122
namespace framework {
22-
23-
template <typename... T>
24-
struct SizeOfTypeFunctor;
25-
26-
template <typename T>
27-
struct SizeOfTypeFunctor<T> {
28-
size_t operator()(std::type_index type) const {
29-
if (typeid(T).hash_code() == type.hash_code()) {
30-
return sizeof(T);
31-
} else {
32-
return 0UL;
33-
}
34-
}
35-
};
36-
37-
template <>
38-
struct SizeOfTypeFunctor<> {
39-
size_t operator()(std::type_index type) const { return 0UL; }
40-
};
41-
42-
template <typename HEAD, typename... TAIL>
43-
struct SizeOfTypeFunctor<HEAD, TAIL...> {
44-
size_t operator()(std::type_index type) const {
45-
SizeOfTypeFunctor<HEAD> head;
46-
size_t head_size = head(type);
47-
if (head_size != 0) {
48-
return head_size;
49-
}
50-
SizeOfTypeFunctor<TAIL...> tail;
51-
return tail(type);
52-
}
53-
};
54-
55-
static inline size_t SizeOfType(std::type_index type) {
56-
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
57-
platform::float16>
58-
functor;
59-
size_t size = functor(type);
60-
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
61-
return size;
62-
}
63-
23+
extern size_t SizeOfType(std::type_index type);
6424
inline void Tensor::check_memory_size() const {
6525
PADDLE_ENFORCE_NOT_NULL(
6626
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");

0 commit comments

Comments
 (0)