
Commit 266ccaa

Integrate float16 into data_type_transform (#8619)
* test cpu float16 data transform
* add isnan etc
* small fix
* fix containsNAN test error
* add data_type transform GPU test
* add float16 GPU example
* fix error
* fix GPU test error
* add context wait
Parent: 78c884d

File tree

11 files changed, +527 -65 lines


paddle/fluid/framework/CMakeLists.txt

Lines changed: 10 additions & 5 deletions
@@ -5,14 +5,14 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
 
-if (WITH_GPU)
+if(WITH_GPU)
   nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto)
 else()
   cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto)
-endif ()
+endif()
 
 cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
-if (WITH_GPU)
+if(WITH_GPU)
   nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
 else()
   cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)

@@ -39,8 +39,13 @@ cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
 nv_test(data_device_transform_test SRCS data_device_transform_test.cu
         DEPS operator op_registry init math_function)
 
-cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
-cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
+if(WITH_GPU)
+  nv_library(data_type_transform SRCS data_type_transform.cu DEPS tensor)
+  nv_test(data_type_transform_test SRCS data_type_transform_test.cc data_type_transform_test.cu DEPS data_type_transform)
+else()
+  cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
+  cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
+endif()
 
 cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
 cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
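
Note on the build change above: the cc_ and nv_ variants differ only in which compiler processes the source, and the GPU branch inside the implementation is selected at compile time because nvcc predefines __NVCC__. A minimal standalone sketch of that mechanism (no Paddle code involved; compile it once with g++ and once with nvcc to see both paths):

#include <cstdio>

int main() {
  // __NVCC__ is predefined by nvcc; a host-only C++ compiler never defines
  // it, so the same source builds a CPU-only or a CPU+GPU variant depending
  // on which compiler drives the build.
#ifdef __NVCC__
  std::printf("built by nvcc: GPU code paths compiled in\n");
#else
  std::printf("built by a host C++ compiler: CPU paths only\n");
#endif
  return 0;
}

This is why data_type_transform.cu is registered under WITH_GPU: recompiling the same implementation with nvcc activates the CUDA branch guarded by #ifdef __NVCC__ in CastDataType below.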

paddle/fluid/framework/data_transform.cc

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ void DataTransform(const OpKernelType& expected_kernel_type,
     PassTensorData(&out, &in);
   }
 
+  // do data type transform
   if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
     TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
     transformed = true;
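
For context, a reduced sketch of the guard that the new comment annotates: DataTransform only invokes TransDataType when the data type the op's kernel expects differs from the one the variable actually holds. Here OpKernelType is pared down to a hypothetical struct with just the relevant field:

#include <iostream>

// Hypothetical stand-in; the real OpKernelType also carries place,
// data layout, and library type.
enum class DataType { FP16, FP32 };
struct OpKernelType { DataType data_type_; };

int main() {
  OpKernelType kernel_type_for_var{DataType::FP16};   // what the tensor holds
  OpKernelType expected_kernel_type{DataType::FP32};  // what the kernel wants

  // do data type transform (mirrors the branch in DataTransform)
  if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
    std::cout << "TransDataType(fp16 -> fp32) would run here\n";
  }
  return 0;
}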

paddle/fluid/framework/data_type.h

Lines changed: 9 additions & 1 deletion
@@ -16,13 +16,16 @@ limitations under the License. */
 #include <typeindex>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace framework {
 
 inline proto::VarType::Type ToDataType(std::type_index type) {
   using namespace paddle::framework::proto;
-  if (typeid(float).hash_code() == type.hash_code()) {
+  if (typeid(platform::float16).hash_code() == type.hash_code()) {
+    return proto::VarType::FP16;
+  } else if (typeid(float).hash_code() == type.hash_code()) {
     return proto::VarType::FP32;
   } else if (typeid(double).hash_code() == type.hash_code()) {
     return proto::VarType::FP64;

@@ -40,6 +43,8 @@ inline proto::VarType::Type ToDataType(std::type_index type) {
 inline std::type_index ToTypeIndex(proto::VarType::Type type) {
   using namespace paddle::framework::proto;
   switch (type) {
+    case proto::VarType::FP16:
+      return typeid(platform::float16);
     case proto::VarType::FP32:
       return typeid(float);
     case proto::VarType::FP64:

@@ -59,6 +64,9 @@ template <typename Visitor>
 inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
   using namespace paddle::framework::proto;
   switch (type) {
+    case proto::VarType::FP16:
+      visitor.template operator()<platform::float16>();
+      break;
     case proto::VarType::FP32:
       visitor.template operator()<float>();
       break;
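
A self-contained sketch of the mapping ToDataType performs, with VarType and float16 as hypothetical stand-ins for proto::VarType::Type and platform::float16; the point is that the new FP16 branch keys off typeid exactly like the existing types:

#include <cassert>
#include <stdexcept>
#include <typeindex>
#include <typeinfo>

enum class VarType { FP16, FP32, FP64 };  // stand-in for proto::VarType
struct float16 { unsigned short x; };     // 2-byte payload, like platform::float16

VarType ToDataType(std::type_index type) {
  if (typeid(float16).hash_code() == type.hash_code()) {
    return VarType::FP16;                 // the branch this commit adds
  } else if (typeid(float).hash_code() == type.hash_code()) {
    return VarType::FP32;
  } else if (typeid(double).hash_code() == type.hash_code()) {
    return VarType::FP64;
  }
  throw std::runtime_error("unsupported type");
}

int main() {
  assert(ToDataType(typeid(float16)) == VarType::FP16);
  assert(ToDataType(typeid(double)) == VarType::FP64);
  return 0;
}

VisitDataType is the inverse direction: given the runtime enum, it instantiates the visitor's templated operator() with the matching C++ type, which is what lets TransDataType below reach CastDataType<platform::float16>.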

paddle/fluid/framework/data_type_transform.cc

Lines changed: 12 additions & 2 deletions
@@ -47,9 +47,15 @@ struct CastDataType {
       auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
       trans(*context, in_begin, in_end, out_begin,
             CastDataTypeFunctor<InType, OutType>());
+#ifdef __NVCC__
+    } else if (platform::is_gpu_place(in_.place())) {
+      platform::Transform<platform::CUDADeviceContext> trans;
+      auto* context = static_cast<const platform::CUDADeviceContext*>(ctx_);
+      trans(*context, in_begin, in_end, out_begin,
+            CastDataTypeFunctor<InType, OutType>());
+#endif
     } else {
-      // TODO(dzhwinter): enhance Copy CPU<->GPU with different data type?
-      PADDLE_THROW("Unsupport CPU <-> GPU!");
+      PADDLE_THROW("Unsupported place!");
     }
   }
 };

@@ -65,6 +71,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
   auto ctx = pool.Get(in.place());
 
   switch (src_type) {
+    case proto::VarType::FP16:
+      framework::VisitDataType(dst_type,
+                               CastDataType<platform::float16>(in, out, ctx));
+      break;
     case proto::VarType::FP32:
       framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
       break;
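
The control flow above is a two-level dispatch: TransDataType switches on the source type to fix InType, then VisitDataType switches on the destination type to instantiate the visitor's templated operator() with OutType. A runnable miniature of that pattern (all names are simplified stand-ins; raw buffers replace Tensor):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

enum class VarType { FP32, FP64, INT32 };  // stand-in for proto::VarType

// Second level: pick OutType from the runtime dst_type.
template <typename Visitor>
void VisitDataType(VarType type, Visitor visitor) {
  switch (type) {
    case VarType::FP32:  visitor.template operator()<float>();   break;
    case VarType::FP64:  visitor.template operator()<double>();  break;
    case VarType::INT32: visitor.template operator()<int32_t>(); break;
  }
}

// Miniature CastDataType: InType is fixed by the caller, OutType by the visitor.
template <typename InType>
struct CastDataType {
  const void* in;  // type-erased input buffer (the real code carries a Tensor)
  void* out;       // pre-sized output buffer
  std::size_t n;

  template <typename OutType>
  void operator()() const {
    auto* src = static_cast<const InType*>(in);
    auto* dst = static_cast<OutType*>(out);
    for (std::size_t i = 0; i < n; ++i)
      dst[i] = static_cast<OutType>(src[i]);  // CastDataTypeFunctor's role
  }
};

// First level: pick InType from the runtime src_type.
void TransDataType(VarType src_type, VarType dst_type,
                   const void* in, void* out, std::size_t n) {
  switch (src_type) {
    case VarType::FP32:
      VisitDataType(dst_type, CastDataType<float>{in, out, n});
      break;
    case VarType::FP64:
      VisitDataType(dst_type, CastDataType<double>{in, out, n});
      break;
    case VarType::INT32:
      VisitDataType(dst_type, CastDataType<int32_t>{in, out, n});
      break;
  }
}

int main() {
  std::vector<float> in = {0.f, 1.f, 2.f};
  std::vector<double> out(in.size());
  TransDataType(VarType::FP32, VarType::FP64, in.data(), out.data(), in.size());
  assert(out[2] == 2.0);
  return 0;
}

Each supported pair thus gets a concrete CastDataTypeFunctor<InType, OutType> instantiation, so adding FP16 to both switches, as this commit does, enables every fp16 <-> other-type combination at once.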
paddle/fluid/framework/data_type_transform.cu

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+data_type_transform.cc

paddle/fluid/framework/data_type_transform_test.cc

Lines changed: 131 additions & 18 deletions
@@ -22,32 +22,145 @@ TEST(DataTypeTransform, CPUTransform) {
 
   auto place = CPUPlace();
 
-  Tensor in;
-  Tensor out;
-
-  float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
-  int data_number = 2 * 3;
-
-  for (int i = 0; i < data_number; ++i) {
-    ptr[i] = i / 3;
-  }
-
+  auto kernel_fp16 = OpKernelType(proto::VarType::FP16, place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
   auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
   auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
                                   DataLayout::kAnyLayout, LibraryType::kPlain);
   auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
                                    DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_int64 = OpKernelType(proto::VarType::INT64, place,
+                                   DataLayout::kAnyLayout, LibraryType::kPlain);
+  auto kernel_bool = OpKernelType(proto::VarType::BOOL, place,
+                                  DataLayout::kAnyLayout, LibraryType::kPlain);
 
-  TransDataType(kernel_fp32, kernel_fp64, in, &out);
-  double* out_data_double = out.data<double>();
-  for (int i = 0; i < data_number; ++i) {
-    ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
+  // data type transform from float32
+  {
+    Tensor in;
+    Tensor out;
+
+    float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i / 3;
+    }
+
+    TransDataType(kernel_fp32, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
+    }
+
+    TransDataType(kernel_fp32, kernel_int32, in, &out);
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
+    }
   }
 
-  TransDataType(kernel_fp32, kernel_int32, in, &out);
-  int* out_data_int = out.data<int>();
-  for (int i = 0; i < data_number; ++i) {
-    ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
+  // data type transform from/to float16
+  {
+    Tensor in;
+    Tensor out;
+
+    float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i;
+    }
+
+    // transform from float16 to other data types
+    TransDataType(kernel_fp16, kernel_fp32, in, &out);
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_int32, in, &out);
+    int* out_data_int = out.data<int>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_int64, in, &out);
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    TransDataType(kernel_fp16, kernel_bool, in, &out);
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to float16
+    float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    TransDataType(kernel_fp32, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
+    }
+
+    // transform double to float16
+    double* in_data_double = in.mutable_data<double>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    TransDataType(kernel_fp64, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
+    }
+
+    // transform int to float16
+    int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int[i] = i;
+    }
+
+    TransDataType(kernel_int32, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
+    }
+
+    // transform int64 to float16
+    int64_t* in_data_int64 = in.mutable_data<int64_t>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    TransDataType(kernel_int64, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
+    }
+
+    // transform bool to float16
+    bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    TransDataType(kernel_bool, kernel_fp16, in, &out);
+    ptr = out.data<float16>();
+    for (int i = 0; i < data_number; ++i) {
+      ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
+    }
   }
 }
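
Note how the fp16 assertions compare raw bits via the public member .x rather than the half values themselves: platform::float16 keeps its IEEE binary16 payload in a uint16_t, so comparing .x pins the result to an exact bit pattern. A minimal sketch assuming only that storage layout:

#include <cassert>
#include <cstdint>

// Hypothetical half type mirroring platform::float16's storage: the raw
// binary16 bits live in a public uint16_t member named x.
struct float16 {
  uint16_t x;
};

int main() {
  float16 a{0x3C00};  // bit pattern of 1.0 in IEEE binary16
  float16 b{0x3C00};
  // Bit-for-bit equality; avoids depending on floating-point comparison
  // semantics or on the half type defining operator==.
  assert(a.x == b.x);
  return 0;
}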
