@@ -27,72 +27,15 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
27
#include <stdio.h>
28
28
#include <stdint.h>
29
29
#include "../common.h"
30
+
31
+ #include "test_helpers.h"
32
+
30
33
#define SGEMM BLASFUNC(sgemm)
31
34
#define SBGEMM BLASFUNC(sbgemm)
32
35
#define SGEMV BLASFUNC(sgemv)
33
36
#define SBGEMV BLASFUNC(sbgemv)
34
- typedef union
35
- {
36
- unsigned short v ;
37
- #if defined(_AIX )
38
- struct __attribute__((packed ))
39
- #else
40
- struct
41
- #endif
42
- {
43
- #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
44
- unsigned short s :1 ;
45
- unsigned short e :8 ;
46
- unsigned short m :7 ;
47
- #else
48
- unsigned short m :7 ;
49
- unsigned short e :8 ;
50
- unsigned short s :1 ;
51
- #endif
52
- } bits ;
53
- } bfloat16_bits ;
54
-
55
- typedef union
56
- {
57
- float v ;
58
- #if defined(_AIX )
59
- struct __attribute__((packed ))
60
- #else
61
- struct
62
- #endif
63
- {
64
- #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
65
- uint32_t s :1 ;
66
- uint32_t e :8 ;
67
- uint32_t m :23 ;
68
- #else
69
- uint32_t m :23 ;
70
- uint32_t e :8 ;
71
- uint32_t s :1 ;
72
- #endif
73
- } bits ;
74
- } float32_bits ;
75
-
76
- float
77
- float16to32 (bfloat16_bits f16 )
78
- {
79
- float32_bits f32 ;
80
- f32 .bits .s = f16 .bits .s ;
81
- f32 .bits .e = f16 .bits .e ;
82
- f32 .bits .m = (uint32_t ) f16 .bits .m << 16 ;
83
- return f32 .v ;
84
- }
85
-
86
37
#define SBGEMM_LARGEST 256
87
38
88
- void * malloc_safe (size_t size )
89
- {
90
- if (size == 0 )
91
- return malloc (1 );
92
- else
93
- return malloc (size );
94
- }
95
-
96
39
int
97
40
main (int argc , char * argv [])
98
41
{
@@ -111,32 +54,29 @@ main (int argc, char *argv[])
111
54
float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
112
55
float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
113
56
float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
114
- bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (m * k * sizeof (bfloat16_bits ));
115
- bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (k * n * sizeof (bfloat16_bits ));
57
+ bfloat16 * AA = (bfloat16 * )malloc_safe (m * k * sizeof (bfloat16 ));
58
+ bfloat16 * BB = (bfloat16 * )malloc_safe (k * n * sizeof (bfloat16 ));
116
59
float * DD = (float * )malloc_safe (m * n * sizeof (FLOAT ));
117
60
float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
118
61
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
119
62
(DD == NULL ) || (CC == NULL ))
120
63
return 1 ;
121
- bfloat16 atmp ,btmp ;
122
64
blasint one = 1 ;
123
65
124
66
for (j = 0 ; j < m ; j ++ )
125
67
{
126
68
for (i = 0 ; i < k ; i ++ )
127
69
{
128
70
A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
129
- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
130
- AA [j * k + i ].v = atmp ;
71
+ sbstobf16_ (& one , & A [j * k + i ], & one , & AA [j * k + i ], & one );
131
72
}
132
73
}
133
74
for (j = 0 ; j < n ; j ++ )
134
75
{
135
76
for (i = 0 ; i < k ; i ++ )
136
77
{
137
78
B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
138
- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
139
- BB [j * k + i ].v = btmp ;
79
+ sbstobf16_ (& one , & B [j * k + i ], & one , & BB [j * k + i ], & one );
140
80
}
141
81
}
142
82
for (y = 0 ; y < 4 ; y ++ )
@@ -182,10 +122,12 @@ main (int argc, char *argv[])
182
122
DD [i * m + j ] +=
183
123
float16to32 (AA [k * j + l ]) * float16to32 (BB [i + l * n ]);
184
124
}
185
- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
125
+ if (! is_close (CC [i * m + j ], C [i * m + j ], 0.01 , 0.001 )) {
186
126
ret ++ ;
187
- if (fabs (CC [i * m + j ] - DD [i * m + j ]) > 1.0 )
127
+ }
128
+ if (!is_close (CC [i * m + j ], DD [i * m + j ], 0.001 , 0.0001 )) {
188
129
ret ++ ;
130
+ }
189
131
}
190
132
}
191
133
free (A );
@@ -211,27 +153,24 @@ main (int argc, char *argv[])
211
153
float * A = (float * )malloc_safe (x * x * sizeof (FLOAT ));
212
154
float * B = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
213
155
float * C = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
214
- bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (x * x * sizeof (bfloat16_bits ));
215
- bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ) << l );
156
+ bfloat16 * AA = (bfloat16 * )malloc_safe (x * x * sizeof (bfloat16 ));
157
+ bfloat16 * BB = (bfloat16 * )malloc_safe (x * sizeof (bfloat16 ) << l );
216
158
float * DD = (float * )malloc_safe (x * sizeof (FLOAT ));
217
159
float * CC = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
218
160
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
219
161
(DD == NULL ) || (CC == NULL ))
220
162
return 1 ;
221
- bfloat16 atmp , btmp ;
222
163
blasint one = 1 ;
223
164
224
165
for (j = 0 ; j < x ; j ++ )
225
166
{
226
167
for (i = 0 ; i < x ; i ++ )
227
168
{
228
169
A [j * x + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
229
- sbstobf16_ (& one , & A [j * x + i ], & one , & atmp , & one );
230
- AA [j * x + i ].v = atmp ;
170
+ sbstobf16_ (& one , & A [j * x + i ], & one , & AA [j * x + i ], & one );
231
171
}
232
172
B [j << l ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
233
- sbstobf16_ (& one , & B [j << l ], & one , & btmp , & one );
234
- BB [j << l ].v = btmp ;
173
+ sbstobf16_ (& one , & B [j << l ], & one , & BB [j << l ], & one );
235
174
236
175
CC [j << l ] = C [j << l ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
237
176
}
@@ -262,10 +201,12 @@ main (int argc, char *argv[])
262
201
}
263
202
264
203
for (j = 0 ; j < x ; j ++ ) {
265
- if (fabs (CC [j << l ] - C [j << l ]) > 1.0 )
204
+ if (! is_close (CC [j << l ], C [j << l ], 0.01 , 0.001 )) {
266
205
ret ++ ;
267
- if (fabs (CC [j << l ] - DD [j ]) > 1.0 )
206
+ }
207
+ if (!is_close (CC [j << l ], DD [j ], 0.001 , 0.0001 )) {
268
208
ret ++ ;
209
+ }
269
210
}
270
211
}
271
212
free (A );
0 commit comments