Skip to content

Commit 6452f7b

Browse files
authored
Merge pull request #4873 from ChipKerchner/fixSBGEMMDefaults
[POWER] Problem with multi-threaded SBGEMM
2 parents ca7777d + 3122674 commit 6452f7b

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

param.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2637,8 +2637,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26372637
#undef SBGEMM_DEFAULT_Q
26382638
#define SBGEMM_DEFAULT_UNROLL_M 16
26392639
#define SBGEMM_DEFAULT_UNROLL_N 8
2640-
#define SBGEMM_DEFAULT_P 832
2641-
#define SBGEMM_DEFAULT_Q 1026
2640+
#define SBGEMM_DEFAULT_P 512
2641+
#define SBGEMM_DEFAULT_Q 1024
26422642
#define SBGEMM_DEFAULT_R 4096
26432643
#endif
26442644

test/compare_sgemm_sbgemm.c

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,54 @@ float16to32 (bfloat16_bits f16)
8181
return f32.v;
8282
}
8383

84+
#define SBGEMM_LARGEST 256
85+
8486
int
8587
main (int argc, char *argv[])
8688
{
8789
blasint m, n, k;
8890
int i, j, l;
8991
blasint x, y;
9092
int ret = 0;
91-
int loop = 100;
93+
int loop = SBGEMM_LARGEST;
9294
char transA = 'N', transB = 'N';
9395
float alpha = 1.0, beta = 0.0;
9496

9597
for (x = 0; x <= loop; x++)
9698
{
99+
if ((x > 100) && (x != SBGEMM_LARGEST)) continue;
100+
m = k = n = x;
101+
float *A = (float *)malloc(m * k * sizeof(FLOAT));
102+
float *B = (float *)malloc(k * n * sizeof(FLOAT));
103+
float *C = (float *)malloc(m * n * sizeof(FLOAT));
104+
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits));
105+
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits));
106+
float *DD = (float *)malloc(m * n * sizeof(FLOAT));
107+
float *CC = (float *)malloc(m * n * sizeof(FLOAT));
108+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
109+
(DD == NULL) || (CC == NULL))
110+
return 1;
111+
bfloat16 atmp,btmp;
112+
blasint one=1;
113+
114+
for (j = 0; j < m; j++)
115+
{
116+
for (i = 0; i < k; i++)
117+
{
118+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
119+
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
120+
AA[j * k + i].v = atmp;
121+
}
122+
}
123+
for (j = 0; j < n; j++)
124+
{
125+
for (i = 0; i < k; i++)
126+
{
127+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
128+
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
129+
BB[j * k + i].v = btmp;
130+
}
131+
}
97132
for (y = 0; y < 4; y++)
98133
{
99134
if ((y == 0) || (y == 2)) {
@@ -106,40 +141,19 @@ main (int argc, char *argv[])
106141
} else {
107142
transB = 'T';
108143
}
109-
m = k = n = x;
110-
float A[m * k];
111-
float B[k * n];
112-
float C[m * n];
113-
bfloat16_bits AA[m * k], BB[k * n];
114-
float DD[m * n], CC[m * n];
115-
bfloat16 atmp,btmp;
116-
blasint one=1;
117144

118-
for (j = 0; j < m; j++)
119-
{
120-
for (i = 0; i < m; i++)
121-
{
122-
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
123-
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
124-
C[j * k + i] = 0;
125-
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
126-
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
127-
AA[j * k + i].v = atmp;
128-
BB[j * k + i].v = btmp;
129-
CC[j * k + i] = 0;
130-
DD[j * k + i] = 0;
131-
}
132-
}
145+
memset(CC, 0, m * n * sizeof(FLOAT));
146+
memset(DD, 0, m * n * sizeof(FLOAT));
147+
memset(C, 0, m * n * sizeof(FLOAT));
148+
133149
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
134150
&m, B, &k, &beta, C, &m);
135151
SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA,
136152
&m, (bfloat16*)BB, &k, &beta, CC, &m);
153+
137154
for (i = 0; i < n; i++)
138155
for (j = 0; j < m; j++)
139-
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
140-
ret++;
141-
for (i = 0; i < n; i++)
142-
for (j = 0; j < m; j++)
156+
{
143157
for (l = 0; l < k; l++)
144158
if (transA == 'N' && transB == 'N')
145159
{
@@ -158,11 +172,19 @@ main (int argc, char *argv[])
158172
DD[i * m + j] +=
159173
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]);
160174
}
161-
for (i = 0; i < n; i++)
162-
for (j = 0; j < m; j++)
163-
if (CC[i * m + j] != DD[i * m + j])
175+
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
176+
ret++;
177+
if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0)
164178
ret++;
179+
}
165180
}
181+
free(A);
182+
free(B);
183+
free(C);
184+
free(AA);
185+
free(BB);
186+
free(DD);
187+
free(CC);
166188
}
167189

168190
if (ret != 0)

0 commit comments

Comments
 (0)