Skip to content

Commit 13c2888

Browse files
authored
Update "cosmetic fixes for non-C99 compilers"
1 parent 28915ee commit 13c2888

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

test/compare_sgemm_shgemm.c

Lines changed: 46 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,27 @@ typedef union
4646
} bits;
4747
} bfloat16_bits;
4848

49+
/*
 * Bit-level view of a 32-bit float.
 *
 *   v     - the value accessed as a float
 *   bits  - the same storage split into three bit-fields:
 *             m (23 bits), e (8 bits), s (1 bit)
 *
 * The field names suggest mantissa / exponent / sign of an IEEE-754
 * single-precision value; this assumes `float` is binary32 on the target.
 *
 * The order in which bit-fields are allocated inside a storage unit is
 * implementation-defined.  ABIs for little-endian targets allocate from the
 * least-significant bit, so declaring m, e, s places the sign in bit 31.
 * Big-endian ABIs typically allocate from the most-significant bit, so the
 * declaration order is reversed there to keep each field on the same bits.
 * Compilers that do not define __BYTE_ORDER__ keep the original order.
 */
typedef union
{
  float v;
  struct
  {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) \
    && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    uint32_t s:1;
    uint32_t e:8;
    uint32_t m:23;
#else
    uint32_t m:23;
    uint32_t e:8;
    uint32_t s:1;
#endif
  } bits;
} float32_bits;
59+
60+
float
61+
float16to32 (bfloat16_bits f16)
62+
{
63+
float32_bits f32;
64+
f32.bits.s = f16.bits.s;
65+
f32.bits.e = f16.bits.e;
66+
f32.bits.m = (uint32_t) f16.bits.m << 16;
67+
return f32.v;
68+
}
69+
4970
int
5071
main (int argc, char *argv[])
5172
{
@@ -56,8 +77,6 @@ main (int argc, char *argv[])
5677
int loop = 100;
5778
char transA = 'N', transB = 'N';
5879
float alpha = 1.0, beta = 0.0;
59-
char transa = 'N';
60-
char transb = 'N';
6180

6281
for (x = 0; x <= loop; x++)
6382
{
@@ -66,30 +85,45 @@ main (int argc, char *argv[])
6685
float B[k * n];
6786
float C[m * n];
6887
bfloat16_bits AA[m * k], BB[k * n];
69-
float CC[m * n];
88+
float DD[m * n], CC[m * n];
7089

7190
for (j = 0; j < m; j++)
7291
{
7392
for (i = 0; i < m; i++)
7493
{
75-
A[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
76-
B[j * k + i] = ((FLOAT) rand() / (FLOAT) RAND_MAX) + 0.5;
94+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
95+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
7796
C[j * k + i] = 0;
7897
AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
7998
BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
8099
CC[j * k + i] = 0;
100+
DD[j * k + i] = 0;
81101
}
82102
}
83103
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
84-
&m, B, &k, &beta, C, &m);
104+
&m, B, &k, &beta, C, &m);
85105
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
86-
&m, BB, &k, &beta, CC, &m);
87-
106+
&m, BB, &k, &beta, CC, &m);
88107
for (i = 0; i < n; i++)
89-
for (j = 0; j < m; j++)
90-
for (l = 0; l < k; l++)
91-
if (fabs(CC[i * m + j]-C[i * m + j]) > 1.0)
92-
ret++;
108+
for (j = 0; j < m; j++)
109+
for (l = 0; l < k; l++)
110+
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
111+
ret++;
112+
if (transA == 'N' && transB == 'N')
113+
{
114+
for (i = 0; i < n; i++)
115+
for (j = 0; j < m; j++)
116+
for (l = 0; l < k; l++)
117+
{
118+
DD[i * m + j] +=
119+
float16to32 (AA[l * m + j]) * float16to32 (BB[l + k * i]);
120+
}
121+
for (i = 0; i < n; i++)
122+
for (j = 0; j < m; j++)
123+
for (l = 0; l < k; l++)
124+
if (CC[i * m + j] != DD[i * m + j])
125+
ret++;
126+
}
93127
}
94128
if (ret != 0)
95129
fprintf (stderr, "FATAL ERROR SHGEMM - Return code: %d\n", ret);

0 commit comments

Comments
 (0)