Skip to content

Commit 868aa85

Browse files
author
Chip Kerchner
committed
Change malloc zero to return one byte and update the SBGEMM test to again use sizes of zero.
1 parent b1802f4 commit 868aa85

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

test/compare_sgemm_sbgemm.c

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)
8585

8686
#define SBGEMM_LARGEST 256
8787

88+
void *malloc_safe(size_t size)
89+
{
90+
if (size == 0)
91+
return malloc(1);
92+
else
93+
return malloc(size);
94+
}
95+
8896
int
8997
main (int argc, char *argv[])
9098
{
@@ -96,17 +104,17 @@ main (int argc, char *argv[])
96104
char transA = 'N', transB = 'N';
97105
float alpha = 1.0, beta = 0.0;
98106

99-
for (x = 1; x <= loop; x++)
107+
for (x = 0; x <= loop; x++)
100108
{
101109
if ((x > 100) && (x != SBGEMM_LARGEST)) continue;
102110
m = k = n = x;
103-
float *A = (float *)malloc(m * k * sizeof(FLOAT));
104-
float *B = (float *)malloc(k * n * sizeof(FLOAT));
105-
float *C = (float *)malloc(m * n * sizeof(FLOAT));
106-
bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits));
107-
bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits));
108-
float *DD = (float *)malloc(m * n * sizeof(FLOAT));
109-
float *CC = (float *)malloc(m * n * sizeof(FLOAT));
111+
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
112+
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
113+
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
114+
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits));
115+
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
116+
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
117+
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
110118
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
111119
(DD == NULL) || (CC == NULL))
112120
return 1;
@@ -195,15 +203,15 @@ main (int argc, char *argv[])
195203
}
196204

197205
k = 1;
198-
for (x = 1; x <= loop; x++)
206+
for (x = 0; x <= loop; x++)
199207
{
200-
float *A = (float *)malloc(x * x * sizeof(FLOAT));
201-
float *B = (float *)malloc(x * sizeof(FLOAT));
202-
float *C = (float *)malloc(x * sizeof(FLOAT));
203-
bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits));
204-
bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits));
205-
float *DD = (float *)malloc(x * sizeof(FLOAT));
206-
float *CC = (float *)malloc(x * sizeof(FLOAT));
208+
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
209+
float *B = (float *)malloc_safe(x * sizeof(FLOAT));
210+
float *C = (float *)malloc_safe(x * sizeof(FLOAT));
211+
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
212+
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits));
213+
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
214+
float *CC = (float *)malloc_safe(x * sizeof(FLOAT));
207215
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
208216
(DD == NULL) || (CC == NULL))
209217
return 1;

0 commit comments

Comments
 (0)