@@ -22,32 +22,145 @@ TEST(DataTypeTransform, CPUTransform) {
22
22
23
23
auto place = CPUPlace ();
24
24
25
- Tensor in;
26
- Tensor out;
27
-
28
- float * ptr = in.mutable_data <float >(make_ddim ({2 , 3 }), place);
29
- int data_number = 2 * 3 ;
30
-
31
- for (int i = 0 ; i < data_number; ++i) {
32
- ptr[i] = i / 3 ;
33
- }
34
-
25
+ auto kernel_fp16 = OpKernelType (proto::VarType::FP16, place,
26
+ DataLayout::kAnyLayout , LibraryType::kPlain );
35
27
auto kernel_fp32 = OpKernelType (proto::VarType::FP32, place,
36
28
DataLayout::kAnyLayout , LibraryType::kPlain );
37
29
auto kernel_fp64 = OpKernelType (proto::VarType::FP64, place,
38
30
DataLayout::kAnyLayout , LibraryType::kPlain );
39
31
auto kernel_int32 = OpKernelType (proto::VarType::INT32, place,
40
32
DataLayout::kAnyLayout , LibraryType::kPlain );
33
+ auto kernel_int64 = OpKernelType (proto::VarType::INT64, place,
34
+ DataLayout::kAnyLayout , LibraryType::kPlain );
35
+ auto kernel_bool = OpKernelType (proto::VarType::BOOL, place,
36
+ DataLayout::kAnyLayout , LibraryType::kPlain );
41
37
42
- TransDataType (kernel_fp32, kernel_fp64, in, &out);
43
- double * out_data_double = out.data <double >();
44
- for (int i = 0 ; i < data_number; ++i) {
45
- ASSERT_EQ (out_data_double[i], static_cast <double >(i / 3 ));
38
+ // data type transform from float32
39
+ {
40
+ Tensor in;
41
+ Tensor out;
42
+
43
+ float * ptr = in.mutable_data <float >(make_ddim ({2 , 3 }), place);
44
+ int data_number = 2 * 3 ;
45
+
46
+ for (int i = 0 ; i < data_number; ++i) {
47
+ ptr[i] = i / 3 ;
48
+ }
49
+
50
+ TransDataType (kernel_fp32, kernel_fp64, in, &out);
51
+ double * out_data_double = out.data <double >();
52
+ for (int i = 0 ; i < data_number; ++i) {
53
+ ASSERT_EQ (out_data_double[i], static_cast <double >(i / 3 ));
54
+ }
55
+
56
+ TransDataType (kernel_fp32, kernel_int32, in, &out);
57
+ int * out_data_int = out.data <int >();
58
+ for (int i = 0 ; i < data_number; ++i) {
59
+ ASSERT_EQ (out_data_int[i], static_cast <int >(i / 3 ));
60
+ }
46
61
}
47
62
48
- TransDataType (kernel_fp32, kernel_int32, in, &out);
49
- int * out_data_int = out.data <int >();
50
- for (int i = 0 ; i < data_number; ++i) {
51
- ASSERT_EQ (out_data_int[i], static_cast <int >(i / 3 ));
63
+ // data type transform from/to float16
64
+ {
65
+ Tensor in;
66
+ Tensor out;
67
+
68
+ float16* ptr = in.mutable_data <float16>(make_ddim ({2 , 3 }), place);
69
+ int data_number = 2 * 3 ;
70
+
71
+ for (int i = 0 ; i < data_number; ++i) {
72
+ ptr[i] = i;
73
+ }
74
+
75
+ // transform from float16 to other data types
76
+ TransDataType (kernel_fp16, kernel_fp32, in, &out);
77
+ float * out_data_float = out.data <float >();
78
+ for (int i = 0 ; i < data_number; ++i) {
79
+ ASSERT_EQ (out_data_float[i], static_cast <float >(ptr[i]));
80
+ }
81
+
82
+ TransDataType (kernel_fp16, kernel_fp64, in, &out);
83
+ double * out_data_double = out.data <double >();
84
+ for (int i = 0 ; i < data_number; ++i) {
85
+ ASSERT_EQ (out_data_double[i], static_cast <double >(ptr[i]));
86
+ }
87
+
88
+ TransDataType (kernel_fp16, kernel_int32, in, &out);
89
+ int * out_data_int = out.data <int >();
90
+ for (int i = 0 ; i < data_number; ++i) {
91
+ ASSERT_EQ (out_data_int[i], static_cast <int >(ptr[i]));
92
+ }
93
+
94
+ TransDataType (kernel_fp16, kernel_int64, in, &out);
95
+ int64_t * out_data_int64 = out.data <int64_t >();
96
+ for (int i = 0 ; i < data_number; ++i) {
97
+ ASSERT_EQ (out_data_int64[i], static_cast <int64_t >(ptr[i]));
98
+ }
99
+
100
+ TransDataType (kernel_fp16, kernel_bool, in, &out);
101
+ bool * out_data_bool = out.data <bool >();
102
+ for (int i = 0 ; i < data_number; ++i) {
103
+ ASSERT_EQ (out_data_bool[i], static_cast <bool >(ptr[i]));
104
+ }
105
+
106
+ // transform float to float16
107
+ float * in_data_float = in.mutable_data <float >(make_ddim ({2 , 3 }), place);
108
+ for (int i = 0 ; i < data_number; ++i) {
109
+ in_data_float[i] = i;
110
+ }
111
+
112
+ TransDataType (kernel_fp32, kernel_fp16, in, &out);
113
+ ptr = out.data <float16>();
114
+ for (int i = 0 ; i < data_number; ++i) {
115
+ ASSERT_EQ (ptr[i].x , static_cast <float16>(in_data_float[i]).x );
116
+ }
117
+
118
+ // transform double to float16
119
+ double * in_data_double = in.mutable_data <double >(make_ddim ({2 , 3 }), place);
120
+ for (int i = 0 ; i < data_number; ++i) {
121
+ in_data_double[i] = i;
122
+ }
123
+
124
+ TransDataType (kernel_fp64, kernel_fp16, in, &out);
125
+ ptr = out.data <float16>();
126
+ for (int i = 0 ; i < data_number; ++i) {
127
+ ASSERT_EQ (ptr[i].x , static_cast <float16>(in_data_double[i]).x );
128
+ }
129
+
130
+ // transform int to float16
131
+ int * in_data_int = in.mutable_data <int >(make_ddim ({2 , 3 }), place);
132
+ for (int i = 0 ; i < data_number; ++i) {
133
+ in_data_int[i] = i;
134
+ }
135
+
136
+ TransDataType (kernel_int32, kernel_fp16, in, &out);
137
+ ptr = out.data <float16>();
138
+ for (int i = 0 ; i < data_number; ++i) {
139
+ ASSERT_EQ (ptr[i].x , static_cast <float16>(in_data_int[i]).x );
140
+ }
141
+
142
+ // transform int64 to float16
143
+ int64_t * in_data_int64 = in.mutable_data <int64_t >(make_ddim ({2 , 3 }), place);
144
+ for (int i = 0 ; i < data_number; ++i) {
145
+ in_data_int64[i] = i;
146
+ }
147
+
148
+ TransDataType (kernel_int64, kernel_fp16, in, &out);
149
+ ptr = out.data <float16>();
150
+ for (int i = 0 ; i < data_number; ++i) {
151
+ ASSERT_EQ (ptr[i].x , static_cast <float16>(in_data_int64[i]).x );
152
+ }
153
+
154
+ // transform bool to float16
155
+ bool * in_data_bool = in.mutable_data <bool >(make_ddim ({2 , 3 }), place);
156
+ for (int i = 0 ; i < data_number; ++i) {
157
+ in_data_bool[i] = i;
158
+ }
159
+
160
+ TransDataType (kernel_bool, kernel_fp16, in, &out);
161
+ ptr = out.data <float16>();
162
+ for (int i = 0 ; i < data_number; ++i) {
163
+ ASSERT_EQ (ptr[i].x , static_cast <float16>(in_data_bool[i]).x );
164
+ }
52
165
}
53
166
}
0 commit comments