@@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)
85
85
86
86
#define SBGEMM_LARGEST 256
87
87
88
+ void * malloc_safe (size_t size )
89
+ {
90
+ if (size == 0 )
91
+ return malloc (1 );
92
+ else
93
+ return malloc (size );
94
+ }
95
+
88
96
int
89
97
main (int argc , char * argv [])
90
98
{
@@ -96,17 +104,17 @@ main (int argc, char *argv[])
96
104
char transA = 'N' , transB = 'N' ;
97
105
float alpha = 1.0 , beta = 0.0 ;
98
106
99
- for (x = 1 ; x <= loop ; x ++ )
107
+ for (x = 0 ; x <= loop ; x ++ )
100
108
{
101
109
if ((x > 100 ) && (x != SBGEMM_LARGEST )) continue ;
102
110
m = k = n = x ;
103
- float * A = (float * )malloc (m * k * sizeof (FLOAT ));
104
- float * B = (float * )malloc (k * n * sizeof (FLOAT ));
105
- float * C = (float * )malloc (m * n * sizeof (FLOAT ));
106
- bfloat16_bits * AA = (bfloat16_bits * )malloc (m * k * sizeof (bfloat16_bits ));
107
- bfloat16_bits * BB = (bfloat16_bits * )malloc (k * n * sizeof (bfloat16_bits ));
108
- float * DD = (float * )malloc (m * n * sizeof (FLOAT ));
109
- float * CC = (float * )malloc (m * n * sizeof (FLOAT ));
111
+ float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
112
+ float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
113
+ float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
114
+ bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (m * k * sizeof (bfloat16_bits ));
115
+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (k * n * sizeof (bfloat16_bits ));
116
+ float * DD = (float * )malloc_safe (m * n * sizeof (FLOAT ));
117
+ float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
110
118
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
111
119
(DD == NULL ) || (CC == NULL ))
112
120
return 1 ;
@@ -195,15 +203,15 @@ main (int argc, char *argv[])
195
203
}
196
204
197
205
k = 1 ;
198
- for (x = 1 ; x <= loop ; x ++ )
206
+ for (x = 0 ; x <= loop ; x ++ )
199
207
{
200
- float * A = (float * )malloc (x * x * sizeof (FLOAT ));
201
- float * B = (float * )malloc (x * sizeof (FLOAT ));
202
- float * C = (float * )malloc (x * sizeof (FLOAT ));
203
- bfloat16_bits * AA = (bfloat16_bits * )malloc (x * x * sizeof (bfloat16_bits ));
204
- bfloat16_bits * BB = (bfloat16_bits * )malloc (x * sizeof (bfloat16_bits ));
205
- float * DD = (float * )malloc (x * sizeof (FLOAT ));
206
- float * CC = (float * )malloc (x * sizeof (FLOAT ));
208
+ float * A = (float * )malloc_safe (x * x * sizeof (FLOAT ));
209
+ float * B = (float * )malloc_safe (x * sizeof (FLOAT ));
210
+ float * C = (float * )malloc_safe (x * sizeof (FLOAT ));
211
+ bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (x * x * sizeof (bfloat16_bits ));
212
+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ));
213
+ float * DD = (float * )malloc_safe (x * sizeof (FLOAT ));
214
+ float * CC = (float * )malloc_safe (x * sizeof (FLOAT ));
207
215
if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
208
216
(DD == NULL ) || (CC == NULL ))
209
217
return 1 ;
0 commit comments