Skip to content

Commit f94c53e

Browse files
authored
Merge pull request #2612 from RajalakshmiSR/testshgemm
Improve shgemm test
2 parents 4fffa55 + 8efba9b commit f94c53e

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

test/compare_sgemm_shgemm.c

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,27 @@ typedef union
4646
} bits;
4747
} bfloat16_bits;
4848

49+
/*
 * Bit-level view of a 32-bit float, split into IEEE-754 binary32 fields.
 *
 *   v       the value accessed as a float
 *   bits.m  23-bit fraction (mantissa)
 *   bits.e  8-bit biased exponent
 *   bits.s  1-bit sign
 *
 * The order in which a compiler allocates bit-fields inside a storage unit
 * is implementation-defined.  Listing m, e, s matches the binary32 layout
 * only when allocation starts at the least significant bit, which is what
 * compilers for little-endian targets typically do.  For targets reported
 * as big-endian the declaration order is reversed so that each field still
 * overlays the intended bits.  When the byte order cannot be detected the
 * original (little-endian) ordering is kept unchanged.
 */
typedef union
{
  float v;
  struct
  {
#if defined (__BYTE_ORDER__) && defined (__ORDER_BIG_ENDIAN__) \
    && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    uint32_t s:1;
    uint32_t e:8;
    uint32_t m:23;
#else
    uint32_t m:23;
    uint32_t e:8;
    uint32_t s:1;
#endif
  } bits;
} float32_bits;
59+
60+
float
61+
float16to32 (bfloat16_bits f16)
62+
{
63+
float32_bits f32;
64+
f32.bits.s = f16.bits.s;
65+
f32.bits.e = f16.bits.e;
66+
f32.bits.m = (uint32_t) f16.bits.m << 16;
67+
return f32.v;
68+
}
69+
4970
int
5071
main (int argc, char *argv[])
5172
{
@@ -55,8 +76,6 @@ main (int argc, char *argv[])
5576
int loop = 100;
5677
char transA = 'N', transB = 'N';
5778
float alpha = 1.0, beta = 0.0;
58-
char transa = 'N';
59-
char transb = 'N';
6079

6180
for (int x = 0; x <= loop; x++)
6281
{
@@ -65,30 +84,45 @@ main (int argc, char *argv[])
6584
float B[k * n];
6685
float C[m * n];
6786
bfloat16_bits AA[m * k], BB[k * n];
68-
float CC[m * n];
87+
float DD[m * n], CC[m * n];
6988

7089
for (int j = 0; j < m; j++)
7190
{
7291
for (int i = 0; i < m; i++)
7392
{
74-
A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
75-
B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
93+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
94+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
7695
C[j * k + i] = 0;
7796
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
7897
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
7998
CC[j * k + i] = 0;
99+
DD[j * k + i] = 0;
80100
}
81101
}
82102
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
83-
&m, B, &k, &beta, C, &m);
103+
&m, B, &k, &beta, C, &m);
84104
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
85-
&m, BB, &k, &beta, CC, &m);
86-
105+
&m, BB, &k, &beta, CC, &m);
87106
for (i = 0; i < n; i++)
88-
for (j = 0; j < m; j++)
89-
for (l = 0; l < k; l++)
90-
if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0)
91-
ret++;
107+
for (j = 0; j < m; j++)
108+
for (l = 0; l < k; l++)
109+
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
110+
ret++;
111+
if (transA == 'N' && transB == 'N')
112+
{
113+
for (i = 0; i < n; i++)
114+
for (j = 0; j < m; j++)
115+
for (l = 0; l < k; l++)
116+
{
117+
DD[i * m + j] +=
118+
float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]);
119+
}
120+
for (i = 0; i < n; i++)
121+
for (j = 0; j < m; j++)
122+
for (l = 0; l < k; l++)
123+
if (CC[i * m + j] != DD[i * m + j])
124+
ret++;
125+
}
92126
}
93127
if (ret != 0)
94128
fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret);

0 commit comments

Comments
 (0)