Skip to content

Commit b173769

Browse files
author
Chip Kerchner
committed
Fix DEFAULTS in SBGEMM for POWER10. Also comparisons for SBGEMM unit test can be exactly due to epilison differences.
1 parent cd3945b commit b173769

File tree

2 files changed

+45
-27
lines changed

2 files changed

+45
-27
lines changed

param.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,8 +2637,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26372637
#undef SBGEMM_DEFAULT_Q
26382638
#define SBGEMM_DEFAULT_UNROLL_M 16
26392639
#define SBGEMM_DEFAULT_UNROLL_N 8
2640-
#define SBGEMM_DEFAULT_P 832
2641-
#define SBGEMM_DEFAULT_Q 1026
2640+
#define SBGEMM_DEFAULT_P 512
2641+
#define SBGEMM_DEFAULT_Q 1024
26422642
#define SBGEMM_DEFAULT_R 4096
26432643
#endif
26442644

test/compare_sgemm_sbgemm.c

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,48 @@ float16to32 (bfloat16_bits f16)
8181
return f32.v;
8282
}
8383

84+
#define SBGEMM_LARGEST 256
85+
8486
int
8587
main (int argc, char *argv[])
8688
{
8789
blasint m, n, k;
8890
int i, j, l;
8991
blasint x, y;
9092
int ret = 0;
91-
int loop = 100;
93+
int loop = SBGEMM_LARGEST;
9294
char transA = 'N', transB = 'N';
9395
float alpha = 1.0, beta = 0.0;
9496

9597
for (x = 0; x <= loop; x++)
9698
{
99+
if ((x > 100) && (x != SBGEMM_LARGEST)) continue;
100+
m = k = n = x;
101+
float *A = (float *)malloc(m * k * sizeof(FLOAT));
102+
float *B = (float *)malloc(k * n * sizeof(FLOAT));
103+
float *C = (float *)malloc(m * n * sizeof(FLOAT));
104+
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits));
105+
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits));
106+
float *DD = (float *)malloc(m * n * sizeof(FLOAT));
107+
float *CC = (float *)malloc(m * n * sizeof(FLOAT));
108+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
109+
(DD == NULL) || (CC == NULL))
110+
return 1;
111+
bfloat16 atmp,btmp;
112+
blasint one=1;
113+
114+
for (j = 0; j < m; j++)
115+
{
116+
for (i = 0; i < n; i++)
117+
{
118+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
119+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
120+
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
121+
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
122+
AA[j * k + i].v = atmp;
123+
BB[j * k + i].v = btmp;
124+
}
125+
}
97126
for (y = 0; y < 4; y++)
98127
{
99128
if ((y == 0) || (y == 2)) {
@@ -106,34 +135,16 @@ main (int argc, char *argv[])
106135
} else {
107136
transB = 'T';
108137
}
109-
m = k = n = x;
110-
float A[m * k];
111-
float B[k * n];
112-
float C[m * n];
113-
bfloat16_bits AA[m * k], BB[k * n];
114-
float DD[m * n], CC[m * n];
115-
bfloat16 atmp,btmp;
116-
blasint one=1;
117138

118-
for (j = 0; j < m; j++)
119-
{
120-
for (i = 0; i < m; i++)
121-
{
122-
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
123-
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
124-
C[j * k + i] = 0;
125-
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
126-
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
127-
AA[j * k + i].v = atmp;
128-
BB[j * k + i].v = btmp;
129-
CC[j * k + i] = 0;
130-
DD[j * k + i] = 0;
131-
}
132-
}
139+
memset(CC, 0, m * n * sizeof(FLOAT));
140+
memset(DD, 0, m * n * sizeof(FLOAT));
141+
memset(C, 0, m * n * sizeof(FLOAT));
142+
133143
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
134144
&m, B, &k, &beta, C, &m);
135145
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA,
136146
&m, (bfloat16*)BB, &k, &beta, CC, &m);
147+
137148
for (i = 0; i < n; i++)
138149
for (j = 0; j < m; j++)
139150
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
@@ -160,9 +171,16 @@ main (int argc, char *argv[])
160171
}
161172
for (i = 0; i < n; i++)
162173
for (j = 0; j < m; j++)
163-
if (CC[i * m + j] != DD[i * m + j])
174+
if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0)
164175
ret++;
165176
}
177+
free(A);
178+
free(B);
179+
free(C);
180+
free(AA);
181+
free(BB);
182+
free(DD);
183+
free(CC);
166184
}
167185

168186
if (ret != 0)

0 commit comments

Comments
 (0)