@@ -205,15 +205,14 @@ main (int argc, char *argv[])
205
205
for (l = 0 ; l < 2 ; l ++ ) { // l = 1 to test inc_x & inc_y not equal to one.
206
206
for (x = 1 ; x <= loop ; x ++ )
207
207
{
208
- m = l + 1 ;
209
- k = (x == 0 ) ? 0 : m ;
208
+ k = (x == 0 ) ? 0 : l + 1 ;
210
209
float * A = (float * )malloc_safe (x * x * sizeof (FLOAT ));
211
- float * B = (float * )malloc_safe (x * sizeof (FLOAT ) * m );
212
- float * C = (float * )malloc_safe (x * sizeof (FLOAT ) * m );
210
+ float * B = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
211
+ float * C = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
213
212
bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (x * x * sizeof (bfloat16_bits ));
214
- bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ) * m );
213
+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ) << l );
215
214
float * DD = (float * )malloc_safe (x * sizeof (FLOAT ));
216
- float * CC = (float * )malloc_safe (x * sizeof (FLOAT ) * m );
215
+ float * CC = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
217
216
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
218
217
(DD == NULL ) || (CC == NULL ))
219
218
return 1 ;
@@ -228,9 +227,9 @@ main (int argc, char *argv[])
228
227
sbstobf16_ (& one , & A [j * x + i ], & one , & atmp , & one );
229
228
AA [j * x + i ].v = atmp ;
230
229
}
231
- B [j * m ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
232
- sbstobf16_ (& one , & B [j * m ], & one , & btmp , & one );
233
- BB [j * m ].v = btmp ;
230
+ B [j << l ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
231
+ sbstobf16_ (& one , & B [j << l ], & one , & btmp , & one );
232
+ BB [j << l ].v = btmp ;
234
233
}
235
234
for (y = 0 ; y < 2 ; y ++ )
236
235
{
@@ -240,25 +239,25 @@ main (int argc, char *argv[])
240
239
transA = 'T' ;
241
240
}
242
241
243
- memset (CC , 0 , x * m * sizeof (FLOAT ));
242
+ memset (CC , 0 , x * sizeof (FLOAT ) << l );
244
243
memset (DD , 0 , x * sizeof (FLOAT ));
245
- memset (C , 0 , x * m * sizeof (FLOAT ));
244
+ memset (C , 0 , x * sizeof (FLOAT ) << l );
246
245
247
246
SGEMV (& transA , & x , & x , & alpha , A , & x , B , & k , & beta , C , & k );
248
247
SBGEMV (& transA , & x , & x , & alpha , (bfloat16 * ) AA , & x , (bfloat16 * ) BB , & k , & beta , CC , & k );
249
248
250
249
for (j = 0 ; j < x ; j ++ )
251
250
for (i = 0 ; i < x ; i ++ )
252
251
if (transA == 'N' ) {
253
- DD [i ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [j * m ]);
252
+ DD [i ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [j << l ]);
254
253
} else if (transA == 'T' ) {
255
- DD [j ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [i * m ]);
254
+ DD [j ] += float16to32 (AA [j * x + i ]) * float16to32 (BB [i << l ]);
256
255
}
257
256
258
257
for (j = 0 ; j < x ; j ++ ) {
259
- if (fabs (CC [j * m ] - C [j * m ]) > 1.0 )
258
+ if (fabs (CC [j << l ] - C [j << l ]) > 1.0 )
260
259
ret ++ ;
261
- if (fabs (CC [j * m ] - DD [j ]) > 1.0 )
260
+ if (fabs (CC [j << l ] - DD [j ]) > 1.0 )
262
261
ret ++ ;
263
262
}
264
263
}
0 commit comments