Skip to content

Commit c4d6daa

Browse files
committed
Polish SizeOfType
1 parent 711d86b commit c4d6daa

File tree

3 files changed

+17
-42
lines changed

3 files changed

+17
-42
lines changed

paddle/fluid/framework/data_type.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct DataTypeMap {
2121
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
2222
std::unordered_map<proto::VarType::Type, std::type_index> proto_to_cpp_;
2323
std::unordered_map<proto::VarType::Type, std::string> proto_to_str_;
24+
std::unordered_map<std::type_index, size_t> cpp_to_size_;
2425
};
2526

2627
static DataTypeMap g_data_type_map_;
@@ -31,11 +32,13 @@ static inline void RegisterType(proto::VarType::Type proto_type,
3132
g_data_type_map_.proto_to_cpp_.emplace(proto_type, typeid(T));
3233
g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
3334
g_data_type_map_.proto_to_str_.emplace(proto_type, name);
35+
g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T));
3436
}
3537

3638
static int RegisterAllTypes() {
3739
#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
3840

41+
// NOTE: Add your customize type here.
3942
RegType(platform::float16, proto::VarType::FP16);
4043
RegType(float, proto::VarType::FP32);
4144
RegType(double, proto::VarType::FP64);
@@ -78,5 +81,14 @@ std::string DataTypeToString(const proto::VarType::Type type) {
7881
static_cast<int>(type));
7982
}
8083

84+
size_t SizeOfType(std::type_index type) {
85+
std::call_once(register_once_flag_, RegisterAllTypes);
86+
auto it = g_data_type_map_.cpp_to_size_.find(type);
87+
if (it != g_data_type_map_.cpp_to_size_.end()) {
88+
return it->second;
89+
}
90+
PADDLE_THROW("Not support %s as tensor type", type.name());
91+
}
92+
8193
} // namespace framework
8294
} // namespace paddle

paddle/fluid/framework/data_type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ limitations under the License. */
1717
#include <typeindex>
1818
#include "paddle/fluid/framework/framework.pb.h"
1919
#include "paddle/fluid/platform/enforce.h"
20+
2021
#include "paddle/fluid/platform/float16.h"
2122

2223
namespace paddle {
2324
namespace framework {
2425

2526
extern proto::VarType::Type ToDataType(std::type_index type);
2627
extern std::type_index ToTypeIndex(proto::VarType::Type type);
28+
2729
template <typename Visitor>
2830
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
2931
switch (type) {
@@ -51,6 +53,7 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
5153
}
5254

5355
extern std::string DataTypeToString(const proto::VarType::Type type);
56+
extern size_t SizeOfType(std::type_index type);
5457
inline std::ostream& operator<<(std::ostream& out,
5558
const proto::VarType::Type& type) {
5659
out << DataTypeToString(type);

paddle/fluid/framework/tensor_impl.h

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16+
#include "paddle/fluid/framework/data_type.h"
1617
#include "paddle/fluid/memory/memcpy.h"
1718
#include "paddle/fluid/platform/enforce.h"
1819
#include "paddle/fluid/platform/float16.h"
1920

2021
namespace paddle {
2122
namespace framework {
22-
23-
template <typename... T>
24-
struct SizeOfTypeFunctor;
25-
26-
template <typename T>
27-
struct SizeOfTypeFunctor<T> {
28-
size_t operator()(std::type_index type) const {
29-
if (typeid(T).hash_code() == type.hash_code()) {
30-
return sizeof(T);
31-
} else {
32-
return 0UL;
33-
}
34-
}
35-
};
36-
37-
template <>
38-
struct SizeOfTypeFunctor<> {
39-
size_t operator()(std::type_index type) const { return 0UL; }
40-
};
41-
42-
template <typename HEAD, typename... TAIL>
43-
struct SizeOfTypeFunctor<HEAD, TAIL...> {
44-
size_t operator()(std::type_index type) const {
45-
SizeOfTypeFunctor<HEAD> head;
46-
size_t head_size = head(type);
47-
if (head_size != 0) {
48-
return head_size;
49-
}
50-
SizeOfTypeFunctor<TAIL...> tail;
51-
return tail(type);
52-
}
53-
};
54-
55-
static inline size_t SizeOfType(std::type_index type) {
56-
SizeOfTypeFunctor<int, float, double, int16_t, int64_t, bool, size_t,
57-
platform::float16>
58-
functor;
59-
size_t size = functor(type);
60-
PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
61-
return size;
62-
}
63-
23+
extern size_t SizeOfType(std::type_index type);
6424
inline void Tensor::check_memory_size() const {
6525
PADDLE_ENFORCE_NOT_NULL(
6626
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");

0 commit comments

Comments
 (0)