@@ -81,19 +81,48 @@ float16to32 (bfloat16_bits f16)
81
81
return f32 .v ;
82
82
}
83
83
84
+ #define SBGEMM_LARGEST 256 /* upper bound of the size loop in main (m = n = k = x); sizes above 100 are skipped except this value */
85
+
84
86
int
85
87
main (int argc , char * argv [])
86
88
{
87
89
blasint m , n , k ;
88
90
int i , j , l ;
89
91
blasint x , y ;
90
92
int ret = 0 ;
91
- int loop = 100 ;
93
+ int loop = SBGEMM_LARGEST ;
92
94
char transA = 'N' , transB = 'N' ;
93
95
float alpha = 1.0 , beta = 0.0 ;
94
96
95
97
for (x = 0 ; x <= loop ; x ++ )
96
98
{
99
+ if ((x > 100 ) && (x != SBGEMM_LARGEST )) continue ;
100
+ m = k = n = x ;
101
+ float * A = (float * )malloc (m * k * sizeof (FLOAT ));
102
+ float * B = (float * )malloc (k * n * sizeof (FLOAT ));
103
+ float * C = (float * )malloc (m * n * sizeof (FLOAT ));
104
+ bfloat16_bits * AA = (bfloat16_bits * )malloc (m * k * sizeof (bfloat16_bits ));
105
+ bfloat16_bits * BB = (bfloat16_bits * )malloc (k * n * sizeof (bfloat16_bits ));
106
+ float * DD = (float * )malloc (m * n * sizeof (FLOAT ));
107
+ float * CC = (float * )malloc (m * n * sizeof (FLOAT ));
108
+ if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
109
+ (DD == NULL ) || (CC == NULL ))
110
+ return 1 ;
111
+ bfloat16 atmp ,btmp ;
112
+ blasint one = 1 ;
113
+
114
+ for (j = 0 ; j < m ; j ++ )
115
+ {
116
+ for (i = 0 ; i < n ; i ++ )
117
+ {
118
+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
119
+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
120
+ sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
121
+ sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
122
+ AA [j * k + i ].v = atmp ;
123
+ BB [j * k + i ].v = btmp ;
124
+ }
125
+ }
97
126
for (y = 0 ; y < 4 ; y ++ )
98
127
{
99
128
if ((y == 0 ) || (y == 2 )) {
@@ -106,34 +135,16 @@ main (int argc, char *argv[])
106
135
} else {
107
136
transB = 'T' ;
108
137
}
109
- m = k = n = x ;
110
- float A [m * k ];
111
- float B [k * n ];
112
- float C [m * n ];
113
- bfloat16_bits AA [m * k ], BB [k * n ];
114
- float DD [m * n ], CC [m * n ];
115
- bfloat16 atmp ,btmp ;
116
- blasint one = 1 ;
117
138
118
- for (j = 0 ; j < m ; j ++ )
119
- {
120
- for (i = 0 ; i < m ; i ++ )
121
- {
122
- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
123
- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
124
- C [j * k + i ] = 0 ;
125
- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
126
- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
127
- AA [j * k + i ].v = atmp ;
128
- BB [j * k + i ].v = btmp ;
129
- CC [j * k + i ] = 0 ;
130
- DD [j * k + i ] = 0 ;
131
- }
132
- }
139
+ memset (CC , 0 , m * n * sizeof (FLOAT ));
140
+ memset (DD , 0 , m * n * sizeof (FLOAT ));
141
+ memset (C , 0 , m * n * sizeof (FLOAT ));
142
+
133
143
SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
134
144
& m , B , & k , & beta , C , & m );
135
145
SBGEMM (& transA , & transB , & m , & n , & k , & alpha , (bfloat16 * ) AA ,
136
146
& m , (bfloat16 * )BB , & k , & beta , CC , & m );
147
+
137
148
for (i = 0 ; i < n ; i ++ )
138
149
for (j = 0 ; j < m ; j ++ )
139
150
if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
@@ -160,9 +171,16 @@ main (int argc, char *argv[])
160
171
}
161
172
for (i = 0 ; i < n ; i ++ )
162
173
for (j = 0 ; j < m ; j ++ )
163
- if (CC [i * m + j ] != DD [i * m + j ])
174
+ if (fabs ( CC [i * m + j ] - DD [i * m + j ]) > 1.0 )
164
175
ret ++ ;
165
176
}
177
+ free (A );
178
+ free (B );
179
+ free (C );
180
+ free (AA );
181
+ free (BB );
182
+ free (DD );
183
+ free (CC );
166
184
}
167
185
168
186
if (ret != 0 )
0 commit comments