@@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
29
#include "../common.h"
30
30
#define SGEMM BLASFUNC(sgemm)
31
31
#define SBGEMM BLASFUNC(sbgemm)
32
+ #define SGEMV BLASFUNC(sgemv)
33
+ #define SBGEMV BLASFUNC(sbgemv)
32
34
typedef union
33
35
{
34
36
unsigned short v ;
@@ -187,7 +189,79 @@ main (int argc, char *argv[])
187
189
free (CC );
188
190
}
189
191
190
- if (ret != 0 )
192
+ if (ret != 0 ) {
191
193
fprintf (stderr , "FATAL ERROR SBGEMM - Return code: %d\n" , ret );
194
+ return ret ;
195
+ }
196
+
197
+ k = 1 ;
198
+ for (x = 1 ; x <= loop ; x ++ )
199
+ {
200
+ float * A = (float * )malloc (x * x * sizeof (FLOAT ));
201
+ float * B = (float * )malloc (x * sizeof (FLOAT ));
202
+ float * C = (float * )malloc (x * sizeof (FLOAT ));
203
+ bfloat16_bits * AA = (bfloat16_bits * )malloc (x * x * sizeof (bfloat16_bits ));
204
+ bfloat16_bits * BB = (bfloat16_bits * )malloc (x * sizeof (bfloat16_bits ));
205
+ float * DD = (float * )malloc (x * sizeof (FLOAT ));
206
+ float * CC = (float * )malloc (x * sizeof (FLOAT ));
207
+ if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
208
+ (DD == NULL ) || (CC == NULL ))
209
+ return 1 ;
210
+ bfloat16 atmp , btmp ;
211
+ blasint one = 1 ;
212
+
213
+ for (j = 0 ; j < x ; j ++ )
214
+ {
215
+ for (i = 0 ; i < x ; i ++ )
216
+ {
217
+ A [j * x + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
218
+ sbstobf16_ (& one , & A [j * x + i ], & one , & atmp , & one );
219
+ AA [j * x + i ].v = atmp ;
220
+ }
221
+ B [j ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
222
+ sbstobf16_ (& one , & B [j ], & one , & btmp , & one );
223
+ BB [j ].v = btmp ;
224
+ }
225
+ for (y = 0 ; y < 2 ; y ++ )
226
+ {
227
+ if (y == 0 ) {
228
+ transA = 'N' ;
229
+ } else {
230
+ transA = 'T' ;
231
+ }
232
+
233
+ memset (CC , 0 , x * sizeof (FLOAT ));
234
+ memset (DD , 0 , x * sizeof (FLOAT ));
235
+ memset (C , 0 , x * sizeof (FLOAT ));
236
+
237
+ SGEMV (& transA , & x , & x , & alpha , A , & x , B , & k , & beta , C , & k );
238
+ SBGEMV (& transA , & x , & x , & alpha , (bfloat16 * ) AA , & x , (bfloat16 * ) BB , & k , & beta , CC , & k );
239
+
240
+ for (j = 0 ; j < x ; j ++ )
241
+ for (i = 0 ; i < x ; i ++ )
242
+ if (transA == 'N' ) {
243
+ DD [i ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [j ]);
244
+ } else if (transA == 'T' ) {
245
+ DD [j ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [i ]);
246
+ }
247
+
248
+ for (j = 0 ; j < x ; j ++ ) {
249
+ if (fabs (CC [j ] - C [j ]) > 1.0 )
250
+ ret ++ ;
251
+ if (fabs (CC [j ] - DD [j ]) > 1.0 )
252
+ ret ++ ;
253
+ }
254
+ }
255
+ free (A );
256
+ free (B );
257
+ free (C );
258
+ free (AA );
259
+ free (BB );
260
+ free (DD );
261
+ free (CC );
262
+ }
263
+
264
+ if (ret != 0 )
265
+ fprintf (stderr , "FATAL ERROR SBGEMV - Return code: %d\n" , ret );
192
266
return ret ;
193
267
}
0 commit comments