@@ -17,43 +17,58 @@ limitations under the License. */
17
17
#include " gtest/gtest.h"
18
18
19
19
TEST (DataTypeTransform, CPUTransform) {
20
- using namespace paddle ::framework;
21
- using namespace paddle ::platform;
22
-
23
- auto place = CPUPlace ();
24
-
25
- auto kernel_fp16 = OpKernelType (proto::VarType::FP16, place,
26
- DataLayout::kAnyLayout , LibraryType::kPlain );
27
- auto kernel_fp32 = OpKernelType (proto::VarType::FP32, place,
28
- DataLayout::kAnyLayout , LibraryType::kPlain );
29
- auto kernel_fp64 = OpKernelType (proto::VarType::FP64, place,
30
- DataLayout::kAnyLayout , LibraryType::kPlain );
31
- auto kernel_int32 = OpKernelType (proto::VarType::INT32, place,
32
- DataLayout::kAnyLayout , LibraryType::kPlain );
33
- auto kernel_int64 = OpKernelType (proto::VarType::INT64, place,
34
- DataLayout::kAnyLayout , LibraryType::kPlain );
35
- auto kernel_bool = OpKernelType (proto::VarType::BOOL, place,
36
- DataLayout::kAnyLayout , LibraryType::kPlain );
20
+ auto place = paddle::platform::CPUPlace ();
21
+
22
+ auto kernel_fp16 = paddle::framework::OpKernelType (
23
+ paddle::framework::proto::VarType::FP16, place,
24
+ paddle::framework::DataLayout::kAnyLayout ,
25
+ paddle::framework::LibraryType::kPlain );
26
+
27
+ auto kernel_fp32 = paddle::framework::OpKernelType (
28
+ paddle::framework::proto::VarType::FP32, place,
29
+ paddle::framework::DataLayout::kAnyLayout ,
30
+ paddle::framework::LibraryType::kPlain );
31
+
32
+ auto kernel_fp64 = paddle::framework::OpKernelType (
33
+ paddle::framework::proto::VarType::FP64, place,
34
+ paddle::framework::DataLayout::kAnyLayout ,
35
+ paddle::framework::LibraryType::kPlain );
36
+
37
+ auto kernel_int32 = paddle::framework::OpKernelType (
38
+ paddle::framework::proto::VarType::INT32, place,
39
+ paddle::framework::DataLayout::kAnyLayout ,
40
+ paddle::framework::LibraryType::kPlain );
41
+
42
+ auto kernel_int64 = paddle::framework::OpKernelType (
43
+ paddle::framework::proto::VarType::INT64, place,
44
+ paddle::framework::DataLayout::kAnyLayout ,
45
+ paddle::framework::LibraryType::kPlain );
46
+
47
+ auto kernel_bool = paddle::framework::OpKernelType (
48
+ paddle::framework::proto::VarType::BOOL, place,
49
+ paddle::framework::DataLayout::kAnyLayout ,
50
+ paddle::framework::LibraryType::kPlain );
37
51
38
52
// data type transform from float32
39
53
{
40
- Tensor in;
41
- Tensor out;
54
+ paddle::framework:: Tensor in;
55
+ paddle::framework:: Tensor out;
42
56
43
- float * ptr = in.mutable_data <float >(make_ddim ({2 , 3 }), place);
57
+ float * ptr =
58
+ in.mutable_data <float >(paddle::framework::make_ddim ({2 , 3 }), place);
44
59
int data_number = 2 * 3 ;
45
60
46
61
for (int i = 0 ; i < data_number; ++i) {
47
62
ptr[i] = i / 3 ;
48
63
}
49
64
50
- TransDataType (kernel_fp32, kernel_fp64, in, &out);
65
+ paddle::framework:: TransDataType (kernel_fp32, kernel_fp64, in, &out);
51
66
double * out_data_double = out.data <double >();
52
67
for (int i = 0 ; i < data_number; ++i) {
53
68
EXPECT_EQ (out_data_double[i], static_cast <double >(i / 3 ));
54
69
}
55
70
56
- TransDataType (kernel_fp32, kernel_int32, in, &out);
71
+ paddle::framework:: TransDataType (kernel_fp32, kernel_int32, in, &out);
57
72
int * out_data_int = out.data <int >();
58
73
for (int i = 0 ; i < data_number; ++i) {
59
74
EXPECT_EQ (out_data_int[i], static_cast <int >(i / 3 ));
@@ -62,105 +77,116 @@ TEST(DataTypeTransform, CPUTransform) {
62
77
63
78
// data type transform from/to float16
64
79
{
65
- Tensor in;
66
- Tensor out;
80
+ paddle::framework:: Tensor in;
81
+ paddle::framework:: Tensor out;
67
82
68
- float16* ptr = in.mutable_data <float16>(make_ddim ({2 , 3 }), place);
83
+ paddle::platform::float16* ptr = in.mutable_data <paddle::platform::float16>(
84
+ paddle::framework::make_ddim ({2 , 3 }), place);
69
85
int data_number = 2 * 3 ;
70
86
71
87
for (int i = 0 ; i < data_number; ++i) {
72
88
ptr[i] = i;
73
89
}
74
90
75
91
// transform from float16 to other data types
76
- TransDataType (kernel_fp16, kernel_fp32, in, &out);
92
+ paddle::framework:: TransDataType (kernel_fp16, kernel_fp32, in, &out);
77
93
float * out_data_float = out.data <float >();
78
94
for (int i = 0 ; i < data_number; ++i) {
79
95
EXPECT_EQ (out_data_float[i], static_cast <float >(ptr[i]));
80
96
}
81
97
82
- TransDataType (kernel_fp16, kernel_fp64, in, &out);
98
+ paddle::framework:: TransDataType (kernel_fp16, kernel_fp64, in, &out);
83
99
double * out_data_double = out.data <double >();
84
100
for (int i = 0 ; i < data_number; ++i) {
85
101
EXPECT_EQ (out_data_double[i], static_cast <double >(ptr[i]));
86
102
}
87
103
88
- TransDataType (kernel_fp16, kernel_int32, in, &out);
104
+ paddle::framework:: TransDataType (kernel_fp16, kernel_int32, in, &out);
89
105
int * out_data_int = out.data <int >();
90
106
for (int i = 0 ; i < data_number; ++i) {
91
107
EXPECT_EQ (out_data_int[i], static_cast <int >(ptr[i]));
92
108
}
93
109
94
- TransDataType (kernel_fp16, kernel_int64, in, &out);
110
+ paddle::framework:: TransDataType (kernel_fp16, kernel_int64, in, &out);
95
111
int64_t * out_data_int64 = out.data <int64_t >();
96
112
for (int i = 0 ; i < data_number; ++i) {
97
113
EXPECT_EQ (out_data_int64[i], static_cast <int64_t >(ptr[i]));
98
114
}
99
115
100
- TransDataType (kernel_fp16, kernel_bool, in, &out);
116
+ paddle::framework:: TransDataType (kernel_fp16, kernel_bool, in, &out);
101
117
bool * out_data_bool = out.data <bool >();
102
118
for (int i = 0 ; i < data_number; ++i) {
103
119
EXPECT_EQ (out_data_bool[i], static_cast <bool >(ptr[i]));
104
120
}
105
121
106
122
// transform float to float16
107
- float * in_data_float = in.mutable_data <float >(make_ddim ({2 , 3 }), place);
123
+ float * in_data_float =
124
+ in.mutable_data <float >(paddle::framework::make_ddim ({2 , 3 }), place);
108
125
for (int i = 0 ; i < data_number; ++i) {
109
126
in_data_float[i] = i;
110
127
}
111
128
112
- TransDataType (kernel_fp32, kernel_fp16, in, &out);
113
- ptr = out.data <float16>();
129
+ paddle::framework:: TransDataType (kernel_fp32, kernel_fp16, in, &out);
130
+ ptr = out.data <paddle::platform:: float16>();
114
131
for (int i = 0 ; i < data_number; ++i) {
115
- EXPECT_EQ (ptr[i].x , static_cast <float16>(in_data_float[i]).x );
132
+ EXPECT_EQ (ptr[i].x ,
133
+ static_cast <paddle::platform::float16>(in_data_float[i]).x );
116
134
}
117
135
118
136
// transform double to float16
119
- double * in_data_double = in.mutable_data <double >(make_ddim ({2 , 3 }), place);
137
+ double * in_data_double =
138
+ in.mutable_data <double >(paddle::framework::make_ddim ({2 , 3 }), place);
120
139
for (int i = 0 ; i < data_number; ++i) {
121
140
in_data_double[i] = i;
122
141
}
123
142
124
- TransDataType (kernel_fp64, kernel_fp16, in, &out);
125
- ptr = out.data <float16>();
143
+ paddle::framework:: TransDataType (kernel_fp64, kernel_fp16, in, &out);
144
+ ptr = out.data <paddle::platform:: float16>();
126
145
for (int i = 0 ; i < data_number; ++i) {
127
- EXPECT_EQ (ptr[i].x , static_cast <float16>(in_data_double[i]).x );
146
+ EXPECT_EQ (ptr[i].x ,
147
+ static_cast <paddle::platform::float16>(in_data_double[i]).x );
128
148
}
129
149
130
150
// transform int to float16
131
- int * in_data_int = in.mutable_data <int >(make_ddim ({2 , 3 }), place);
151
+ int * in_data_int =
152
+ in.mutable_data <int >(paddle::framework::make_ddim ({2 , 3 }), place);
132
153
for (int i = 0 ; i < data_number; ++i) {
133
154
in_data_int[i] = i;
134
155
}
135
156
136
- TransDataType (kernel_int32, kernel_fp16, in, &out);
137
- ptr = out.data <float16>();
157
+ paddle::framework:: TransDataType (kernel_int32, kernel_fp16, in, &out);
158
+ ptr = out.data <paddle::platform:: float16>();
138
159
for (int i = 0 ; i < data_number; ++i) {
139
- EXPECT_EQ (ptr[i].x , static_cast <float16>(in_data_int[i]).x );
160
+ EXPECT_EQ (ptr[i].x ,
161
+ static_cast <paddle::platform::float16>(in_data_int[i]).x );
140
162
}
141
163
142
164
// transform int64 to float16
143
- int64_t * in_data_int64 = in.mutable_data <int64_t >(make_ddim ({2 , 3 }), place);
165
+ int64_t * in_data_int64 =
166
+ in.mutable_data <int64_t >(paddle::framework::make_ddim ({2 , 3 }), place);
144
167
for (int i = 0 ; i < data_number; ++i) {
145
168
in_data_int64[i] = i;
146
169
}
147
170
148
- TransDataType (kernel_int64, kernel_fp16, in, &out);
149
- ptr = out.data <float16>();
171
+ paddle::framework:: TransDataType (kernel_int64, kernel_fp16, in, &out);
172
+ ptr = out.data <paddle::platform:: float16>();
150
173
for (int i = 0 ; i < data_number; ++i) {
151
- EXPECT_EQ (ptr[i].x , static_cast <float16>(in_data_int64[i]).x );
174
+ EXPECT_EQ (ptr[i].x ,
175
+ static_cast <paddle::platform::float16>(in_data_int64[i]).x );
152
176
}
153
177
154
178
// transform bool to float16
155
- bool * in_data_bool = in.mutable_data <bool >(make_ddim ({2 , 3 }), place);
179
+ bool * in_data_bool =
180
+ in.mutable_data <bool >(paddle::framework::make_ddim ({2 , 3 }), place);
156
181
for (int i = 0 ; i < data_number; ++i) {
157
182
in_data_bool[i] = i;
158
183
}
159
184
160
- TransDataType (kernel_bool, kernel_fp16, in, &out);
161
- ptr = out.data <float16>();
185
+ paddle::framework:: TransDataType (kernel_bool, kernel_fp16, in, &out);
186
+ ptr = out.data <paddle::platform:: float16>();
162
187
for (int i = 0 ; i < data_number; ++i) {
163
- EXPECT_EQ (ptr[i].x , static_cast <float16>(in_data_bool[i]).x );
188
+ EXPECT_EQ (ptr[i].x ,
189
+ static_cast <paddle::platform::float16>(in_data_bool[i]).x );
164
190
}
165
191
}
166
192
}
0 commit comments