Skip to content

Commit c23897f

Browse files
author
Chip Kerchner
committed
Add GEMV testing to SBGEMx vs SGEMx testing.
1 parent 6452f7b commit c23897f

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

test/compare_sgemm_sbgemm.c

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929
#include "../common.h"
3030
#define SGEMM BLASFUNC(sgemm)
3131
#define SBGEMM BLASFUNC(sbgemm)
32+
#define SGEMV BLASFUNC(sgemv)
33+
#define SBGEMV BLASFUNC(sbgemv)
3234
typedef union
3335
{
3436
unsigned short v;
@@ -187,7 +189,79 @@ main (int argc, char *argv[])
187189
free(CC);
188190
}
189191

190-
if (ret != 0)
192+
if (ret != 0) {
191193
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
194+
return ret;
195+
}
196+
197+
k = 1;
198+
for (x = 1; x <= loop; x++)
199+
{
200+
float *A = (float *)malloc(x * x * sizeof(FLOAT));
201+
float *B = (float *)malloc(x * sizeof(FLOAT));
202+
float *C = (float *)malloc(x * sizeof(FLOAT));
203+
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
204+
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
205+
float *DD = (float *)malloc(x * sizeof(FLOAT));
206+
float *CC = (float *)malloc(x * sizeof(FLOAT));
207+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
208+
(DD == NULL) || (CC == NULL))
209+
return 1;
210+
bfloat16 atmp, btmp;
211+
blasint one = 1;
212+
213+
for (j = 0; j < x; j++)
214+
{
215+
for (i = 0; i < x; i++)
216+
{
217+
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
218+
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
219+
AA[j * x + i].v = atmp;
220+
}
221+
B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
222+
sbstobf16_(&one, &B[j], &one, &btmp, &one);
223+
BB[j].v = btmp;
224+
}
225+
for (y = 0; y < 2; y++)
226+
{
227+
if (y == 0) {
228+
transA = 'N';
229+
} else {
230+
transA = 'T';
231+
}
232+
233+
memset(CC, 0, x * sizeof(FLOAT));
234+
memset(DD, 0, x * sizeof(FLOAT));
235+
memset(C, 0, x * sizeof(FLOAT));
236+
237+
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
238+
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
239+
240+
for (j = 0; j < x; j++)
241+
for (i = 0; i < x; i++)
242+
if (transA == 'N') {
243+
DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]);
244+
} else if (transA == 'T') {
245+
DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]);
246+
}
247+
248+
for (j = 0; j < x; j++) {
249+
if (fabs (CC[j] - C[j]) > 1.0)
250+
ret++;
251+
if (fabs (CC[j] - DD[j]) > 1.0)
252+
ret++;
253+
}
254+
}
255+
free(A);
256+
free(B);
257+
free(C);
258+
free(AA);
259+
free(BB);
260+
free(DD);
261+
free(CC);
262+
}
263+
264+
if (ret != 0)
265+
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
192266
return ret;
193267
}

0 commit comments

Comments
 (0)