Skip to content

Commit c5909c7

Browse files
committed
Feature/tensor type
test=release/1.2
1 parent 847cbdc commit c5909c7

File tree

136 files changed

+417
-587
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

136 files changed

+417
-587
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
8585
out->mutable_data(expected_kernel_type.place_, in.type());
8686

8787
framework::VisitDataType(
88-
framework::ToDataType(in.type()),
88+
in.type(),
8989
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
9090

9191
out->set_layout(expected_kernel_type.data_layout_);
@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
101101
case mkldnn::memory::data_type::f32:
102102
return platform::to_void_cast(tensor.data<float>());
103103
case mkldnn::memory::data_type::s8:
104-
return platform::to_void_cast(tensor.data<char>());
104+
return platform::to_void_cast(tensor.data<int8_t>());
105105
case mkldnn::memory::data_type::u8:
106106
return platform::to_void_cast(tensor.data<unsigned char>());
107107
case mkldnn::memory::data_type::s16:
@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
144144

145145
memory::data_type in_type = ToMKLDNNDataType(in.type());
146146
PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
147-
"Input tensor type is not supported: ", in.type().name());
147+
"Input tensor type is not supported: %s", in.type());
148148
memory::data_type out_type = in_type;
149149

150150
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());

paddle/fluid/framework/data_layout_transform.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
5050
}
5151
}
5252

53-
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
54-
static const std::map<std::type_index, MKLDNNDataType> dict{
55-
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT
56-
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT
57-
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
58-
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
59-
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
60-
auto iter = dict.find(type);
53+
inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
54+
static std::unordered_map<int, MKLDNNDataType> dict{
55+
{DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
56+
{DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
57+
{DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
58+
{DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
59+
{DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
60+
auto iter = dict.find(static_cast<int>(type));
6161
if (iter != dict.end()) return iter->second;
6262
return MKLDNNDataType::data_undef;
6363
}

paddle/fluid/framework/data_type.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct DataTypeMap {
2626
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
2727
std::unordered_map<int, std::type_index> proto_to_cpp_;
2828
std::unordered_map<int, std::string> proto_to_str_;
29-
std::unordered_map<std::type_index, size_t> cpp_to_size_;
29+
std::unordered_map<int, size_t> proto_to_size_;
3030
};
3131

3232
static DataTypeMap* InitDataTypeMap();
@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
4545
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
4646
map->cpp_to_proto_.emplace(typeid(T), proto_type);
4747
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
48-
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
48+
map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
4949
}
5050

5151
static DataTypeMap* InitDataTypeMap() {
@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
5454
#define RegType(cc_type, proto_type) \
5555
RegisterType<cc_type>(retv, proto_type, #cc_type)
5656

57-
// NOTE: Add your customize type here.
58-
RegType(float16, proto::VarType::FP16);
59-
RegType(float, proto::VarType::FP32);
60-
RegType(double, proto::VarType::FP64);
61-
RegType(int, proto::VarType::INT32);
62-
RegType(int64_t, proto::VarType::INT64);
63-
RegType(bool, proto::VarType::BOOL);
64-
RegType(size_t, proto::VarType::SIZE_T);
65-
RegType(int16_t, proto::VarType::INT16);
66-
RegType(uint8_t, proto::VarType::UINT8);
67-
RegType(int8_t, proto::VarType::INT8);
57+
_ForEachDataType_(RegType);
6858

6959
#undef RegType
7060
return retv;
@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
9686
static_cast<int>(type));
9787
}
9888

99-
size_t SizeOfType(std::type_index type) {
100-
auto it = gDataTypeMap().cpp_to_size_.find(type);
101-
if (it != gDataTypeMap().cpp_to_size_.end()) {
89+
size_t SizeOfType(proto::VarType::Type type) {
90+
auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
91+
if (it != gDataTypeMap().proto_to_size_.end()) {
10292
return it->second;
10393
}
104-
PADDLE_THROW("Not support %s as tensor type", type.name());
94+
PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
10595
}
10696

10797
} // namespace framework

paddle/fluid/framework/data_type.h

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,46 +22,59 @@ limitations under the License. */
2222
namespace paddle {
2323
namespace framework {
2424

25+
template <typename T>
26+
struct DataTypeTrait {};
27+
28+
// Stub handle for void
29+
template <>
30+
struct DataTypeTrait<void> {
31+
constexpr static auto DataType = proto::VarType::RAW;
32+
};
33+
34+
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
35+
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
36+
37+
#define _ForEachDataType_(callback) \
38+
_ForEachDataTypeHelper_(callback, float, FP32); \
39+
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
40+
_ForEachDataTypeHelper_(callback, double, FP64); \
41+
_ForEachDataTypeHelper_(callback, int, INT32); \
42+
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
43+
_ForEachDataTypeHelper_(callback, bool, BOOL); \
44+
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
45+
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
46+
_ForEachDataTypeHelper_(callback, int8_t, INT8)
47+
48+
#define DefineDataTypeTrait(cpp_type, proto_type) \
49+
template <> \
50+
struct DataTypeTrait<cpp_type> { \
51+
constexpr static auto DataType = proto_type; \
52+
}
53+
54+
_ForEachDataType_(DefineDataTypeTrait);
55+
56+
#undef DefineDataTypeTrait
57+
2558
extern proto::VarType::Type ToDataType(std::type_index type);
2659
extern std::type_index ToTypeIndex(proto::VarType::Type type);
2760

2861
template <typename Visitor>
2962
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
30-
switch (type) {
31-
case proto::VarType::FP16:
32-
visitor.template apply<platform::float16>();
33-
break;
34-
case proto::VarType::FP32:
35-
visitor.template apply<float>();
36-
break;
37-
case proto::VarType::FP64:
38-
visitor.template apply<double>();
39-
break;
40-
case proto::VarType::INT32:
41-
visitor.template apply<int>();
42-
break;
43-
case proto::VarType::INT64:
44-
visitor.template apply<int64_t>();
45-
break;
46-
case proto::VarType::BOOL:
47-
visitor.template apply<bool>();
48-
break;
49-
case proto::VarType::UINT8:
50-
visitor.template apply<uint8_t>();
51-
break;
52-
case proto::VarType::INT16:
53-
visitor.template apply<int16_t>();
54-
break;
55-
case proto::VarType::INT8:
56-
visitor.template apply<int8_t>();
57-
break;
58-
default:
59-
PADDLE_THROW("Not supported %d", type);
60-
}
63+
#define VisitDataTypeCallback(cpp_type, proto_type) \
64+
do { \
65+
if (type == proto_type) { \
66+
visitor.template apply<cpp_type>(); \
67+
return; \
68+
} \
69+
} while (0)
70+
71+
_ForEachDataType_(VisitDataTypeCallback);
72+
#undef VisitDataTypeCallback
73+
PADDLE_THROW("Not supported %d", type);
6174
}
6275

6376
extern std::string DataTypeToString(const proto::VarType::Type type);
64-
extern size_t SizeOfType(std::type_index type);
77+
extern size_t SizeOfType(proto::VarType::Type type);
6578
inline std::ostream& operator<<(std::ostream& out,
6679
const proto::VarType::Type& type) {
6780
out << DataTypeToString(type);

paddle/fluid/framework/data_type_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ TEST(DataType, float16) {
2626

2727
Tensor tensor;
2828
CPUPlace cpu;
29-
tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
29+
tensor.mutable_data(cpu, dtype);
3030

3131
// test fp16 tensor
32-
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
32+
EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
3333

3434
// test fp16 size
35-
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
35+
EXPECT_EQ(f::SizeOfType(dtype), 2u);
3636

3737
// test debug info
38-
std::string type = "float16";
38+
std::string type = "::paddle::platform::float16";
3939
EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
4040
}

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ void AllReduceOpHandle::RunImpl() {
120120

121121
// Reduce All Tensor to trg in CPU
122122
ReduceLoDTensor func(lod_tensors, &trg);
123-
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
123+
VisitDataType(lod_tensors[0]->type(), func);
124124

125125
for (size_t i = 1; i < local_scopes_.size(); ++i) {
126126
auto &scope =

paddle/fluid/framework/details/fuse_vars_op_handle.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
3333
FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
3434
const platform::Place &place,
3535
const std::unordered_map<std::string, int64_t> &inputs_numel,
36-
const std::type_index &var_type)
36+
const proto::VarType::Type var_type)
3737
: OpHandleBase(node),
3838
local_scope_(local_scope),
3939
place_(place),
@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
5757
Scope *local_scope_;
5858
const platform::Place place_;
5959
const std::unordered_map<std::string, int64_t> inputs_numel_;
60-
const std::type_index type_;
60+
const proto::VarType::Type type_;
6161
int64_t total_numel_;
6262
};
6363
} // namespace details

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void ReduceOpHandle::RunImpl() {
106106
if (!FLAGS_cpu_deterministic) {
107107
ReduceLoDTensor func(lod_tensors,
108108
out_var->GetMutable<framework::LoDTensor>());
109-
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
109+
VisitDataType(lod_tensors[0]->type(), func);
110110
} else {
111111
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
112112
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
@@ -116,7 +116,7 @@ void ReduceOpHandle::RunImpl() {
116116
->FindVar(out_var_handle->name_)
117117
->GetMutable<framework::LoDTensor>();
118118
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
119-
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
119+
VisitDataType(lod_tensors[0]->type(), func);
120120

121121
auto trg = out_var->GetMutable<framework::LoDTensor>();
122122
if (reduce_sum_trg.data<void>() != trg->data<void>()) {

paddle/fluid/framework/dlpack_tensor.cc

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/dlpack_tensor.h"
16-
16+
#include "paddle/fluid/framework/data_type.h"
1717
namespace paddle {
1818
namespace framework {
1919

@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
3636
return dtype;
3737
}
3838

39-
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
40-
#define REG_DL_DATA_TYPE(type) \
41-
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
42-
static const std::unordered_map<std::type_index, ::DLDataType>
43-
type_to_dtype_map({
44-
REG_DL_DATA_TYPE(platform::float16), // NOLINT
45-
REG_DL_DATA_TYPE(float), // NOLINT
46-
REG_DL_DATA_TYPE(double), // NOLINT
47-
REG_DL_DATA_TYPE(int), // NOLINT
48-
REG_DL_DATA_TYPE(int64_t), // NOLINT
49-
REG_DL_DATA_TYPE(bool), // NOLINT
50-
REG_DL_DATA_TYPE(size_t), // NOLINT
51-
REG_DL_DATA_TYPE(int16_t), // NOLINT
52-
REG_DL_DATA_TYPE(uint8_t), // NOLINT
53-
REG_DL_DATA_TYPE(int8_t) // NOLINT
54-
});
39+
static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
40+
static std::unordered_map<int, ::DLDataType> result;
41+
42+
#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
43+
result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
44+
45+
_ForEachDataType_(REG_DL_DATA_TYPE);
46+
#undef REG_DL_DATA_TYPE
47+
return result;
48+
}
49+
50+
static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
51+
static auto type_to_dtype_map = CreateDLDataTypeMap();
5552
static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
56-
auto it = type_to_dtype_map.find(type);
57-
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
58-
type.name());
53+
auto it = type_to_dtype_map.find(static_cast<int>(type));
54+
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
55+
type);
5956
return it->second;
6057
#undef REG_DL_DATA_TYPE
6158
}

paddle/fluid/framework/dlpack_tensor_test.cc

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,11 @@ void TestMainLoop() {
9191
}
9292
}
9393
}
94+
TEST(dlpack, test_all) {
95+
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
9496

95-
#define PADDLE_DLPACK_TEST(type) \
96-
TEST(dlpack, test_##type) { TestMainLoop<type>(); }
97-
98-
using float16 = platform::float16;
99-
PADDLE_DLPACK_TEST(float16);
100-
PADDLE_DLPACK_TEST(float);
101-
PADDLE_DLPACK_TEST(double);
102-
PADDLE_DLPACK_TEST(int);
103-
PADDLE_DLPACK_TEST(int64_t);
104-
PADDLE_DLPACK_TEST(bool);
105-
PADDLE_DLPACK_TEST(size_t);
106-
PADDLE_DLPACK_TEST(int16_t);
107-
PADDLE_DLPACK_TEST(uint8_t);
108-
PADDLE_DLPACK_TEST(int8_t);
109-
110-
#undef PADDLE_DLPACK_TEST
97+
_ForEachDataType_(TestCallback);
98+
}
11199

112100
} // namespace framework
113101
} // namespace paddle

0 commit comments

Comments
 (0)