@@ -81,19 +81,54 @@ float16to32 (bfloat16_bits f16)
81
81
return f32 .v ;
82
82
}
83
83
84
+ #define SBGEMM_LARGEST 256 /* upper bound of main's size loop: m = k = n runs 0..100, then skips straight to this value */
85
+
84
86
int
85
87
main (int argc , char * argv [])
86
88
{
87
89
blasint m , n , k ;
88
90
int i , j , l ;
89
91
blasint x , y ;
90
92
int ret = 0 ;
91
- int loop = 100 ;
93
+ int loop = SBGEMM_LARGEST ;
92
94
char transA = 'N' , transB = 'N' ;
93
95
float alpha = 1.0 , beta = 0.0 ;
94
96
95
97
for (x = 0 ; x <= loop ; x ++ )
96
98
{
99
+ if ((x > 100 ) && (x != SBGEMM_LARGEST )) continue ;
100
+ m = k = n = x ;
101
+ float * A = (float * )malloc (m * k * sizeof (FLOAT ));
102
+ float * B = (float * )malloc (k * n * sizeof (FLOAT ));
103
+ float * C = (float * )malloc (m * n * sizeof (FLOAT ));
104
+ bfloat16_bits * AA = (bfloat16_bits * )malloc (m * k * sizeof (bfloat16_bits ));
105
+ bfloat16_bits * BB = (bfloat16_bits * )malloc (k * n * sizeof (bfloat16_bits ));
106
+ float * DD = (float * )malloc (m * n * sizeof (FLOAT ));
107
+ float * CC = (float * )malloc (m * n * sizeof (FLOAT ));
108
+ if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
109
+ (DD == NULL ) || (CC == NULL ))
110
+ return 1 ;
111
+ bfloat16 atmp ,btmp ;
112
+ blasint one = 1 ;
113
+
114
+ for (j = 0 ; j < m ; j ++ )
115
+ {
116
+ for (i = 0 ; i < k ; i ++ )
117
+ {
118
+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
119
+ sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
120
+ AA [j * k + i ].v = atmp ;
121
+ }
122
+ }
123
+ for (j = 0 ; j < n ; j ++ )
124
+ {
125
+ for (i = 0 ; i < k ; i ++ )
126
+ {
127
+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
128
+ sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
129
+ BB [j * k + i ].v = btmp ;
130
+ }
131
+ }
97
132
for (y = 0 ; y < 4 ; y ++ )
98
133
{
99
134
if ((y == 0 ) || (y == 2 )) {
@@ -106,40 +141,19 @@ main (int argc, char *argv[])
106
141
} else {
107
142
transB = 'T' ;
108
143
}
109
- m = k = n = x ;
110
- float A [m * k ];
111
- float B [k * n ];
112
- float C [m * n ];
113
- bfloat16_bits AA [m * k ], BB [k * n ];
114
- float DD [m * n ], CC [m * n ];
115
- bfloat16 atmp ,btmp ;
116
- blasint one = 1 ;
117
144
118
- for (j = 0 ; j < m ; j ++ )
119
- {
120
- for (i = 0 ; i < m ; i ++ )
121
- {
122
- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
123
- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
124
- C [j * k + i ] = 0 ;
125
- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
126
- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
127
- AA [j * k + i ].v = atmp ;
128
- BB [j * k + i ].v = btmp ;
129
- CC [j * k + i ] = 0 ;
130
- DD [j * k + i ] = 0 ;
131
- }
132
- }
145
+ memset (CC , 0 , m * n * sizeof (FLOAT ));
146
+ memset (DD , 0 , m * n * sizeof (FLOAT ));
147
+ memset (C , 0 , m * n * sizeof (FLOAT ));
148
+
133
149
SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
134
150
& m , B , & k , & beta , C , & m );
135
151
SBGEMM (& transA , & transB , & m , & n , & k , & alpha , (bfloat16 * ) AA ,
136
152
& m , (bfloat16 * )BB , & k , & beta , CC , & m );
153
+
137
154
for (i = 0 ; i < n ; i ++ )
138
155
for (j = 0 ; j < m ; j ++ )
139
- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
140
- ret ++ ;
141
- for (i = 0 ; i < n ; i ++ )
142
- for (j = 0 ; j < m ; j ++ )
156
+ {
143
157
for (l = 0 ; l < k ; l ++ )
144
158
if (transA == 'N' && transB == 'N' )
145
159
{
@@ -158,11 +172,19 @@ main (int argc, char *argv[])
158
172
DD [i * m + j ] +=
159
173
float16to32 (AA [k * j + l ]) * float16to32 (BB [i + l * n ]);
160
174
}
161
- for ( i = 0 ; i < n ; i ++ )
162
- for ( j = 0 ; j < m ; j ++ )
163
- if (CC [i * m + j ] != DD [i * m + j ])
175
+ if ( fabs ( CC [ i * m + j ] - C [ i * m + j ]) > 1.0 )
176
+ ret ++ ;
177
+ if (fabs ( CC [i * m + j ] - DD [i * m + j ]) > 1.0 )
164
178
ret ++ ;
179
+ }
165
180
}
181
+ free (A );
182
+ free (B );
183
+ free (C );
184
+ free (AA );
185
+ free (BB );
186
+ free (DD );
187
+ free (CC );
166
188
}
167
189
168
190
if (ret != 0 )
0 commit comments