@@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)
85
85
86
86
#define SBGEMM_LARGEST 256
87
87
88
+ void * malloc_safe (size_t size )
89
+ {
90
+ if (size == 0 )
91
+ return malloc (1 );
92
+ else
93
+ return malloc (size );
94
+ }
95
+
88
96
int
89
97
main (int argc , char * argv [])
90
98
{
@@ -100,13 +108,13 @@ main (int argc, char *argv[])
100
108
{
101
109
if ((x > 100 ) && (x != SBGEMM_LARGEST )) continue ;
102
110
m = k = n = x ;
103
- float * A = (float * )malloc (m * k * sizeof (FLOAT ));
104
- float * B = (float * )malloc (k * n * sizeof (FLOAT ));
105
- float * C = (float * )malloc (m * n * sizeof (FLOAT ));
106
- bfloat16_bits * AA = (bfloat16_bits * )malloc (m * k * sizeof (bfloat16_bits ));
107
- bfloat16_bits * BB = (bfloat16_bits * )malloc (k * n * sizeof (bfloat16_bits ));
108
- float * DD = (float * )malloc (m * n * sizeof (FLOAT ));
109
- float * CC = (float * )malloc (m * n * sizeof (FLOAT ));
111
+ float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
112
+ float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
113
+ float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
114
+ bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (m * k * sizeof (bfloat16_bits ));
115
+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (k * n * sizeof (bfloat16_bits ));
116
+ float * DD = (float * )malloc_safe (m * n * sizeof (FLOAT ));
117
+ float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
110
118
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
111
119
(DD == NULL ) || (CC == NULL ))
112
120
return 1 ;
@@ -194,16 +202,16 @@ main (int argc, char *argv[])
194
202
return ret ;
195
203
}
196
204
197
- k = 1 ;
198
205
for (x = 1 ; x <= loop ; x ++ )
199
206
{
200
- float * A = (float * )malloc (x * x * sizeof (FLOAT ));
201
- float * B = (float * )malloc (x * sizeof (FLOAT ));
202
- float * C = (float * )malloc (x * sizeof (FLOAT ));
203
- bfloat16_bits * AA = (bfloat16_bits * )malloc (x * x * sizeof (bfloat16_bits ));
204
- bfloat16_bits * BB = (bfloat16_bits * )malloc (x * sizeof (bfloat16_bits ));
205
- float * DD = (float * )malloc (x * sizeof (FLOAT ));
206
- float * CC = (float * )malloc (x * sizeof (FLOAT ));
207
+ k = (x == 0 ) ? 0 : 1 ;
208
+ float * A = (float * )malloc_safe (x * x * sizeof (FLOAT ));
209
+ float * B = (float * )malloc_safe (x * sizeof (FLOAT ));
210
+ float * C = (float * )malloc_safe (x * sizeof (FLOAT ));
211
+ bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (x * x * sizeof (bfloat16_bits ));
212
+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ));
213
+ float * DD = (float * )malloc_safe (x * sizeof (FLOAT ));
214
+ float * CC = (float * )malloc_safe (x * sizeof (FLOAT ));
207
215
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
208
216
(DD == NULL ) || (CC == NULL ))
209
217
return 1 ;
0 commit comments