@@ -100,7 +100,7 @@ int main(int argc, char *argv[])
100
100
SGEMV (& transA , & x , & x , & alpha , A , & x , B , & k , & beta , C , & k );
101
101
BGEMV (& transA , & x , & x , & alpha_bf16 , AA , & x , BB , & k , & beta_bf16 , CC , & k );
102
102
103
- for (int i = 0 ; i < x ; i ++ )
103
+ for (i = 0 ; i < x ; i ++ )
104
104
DD [i ] *= beta ;
105
105
106
106
for (j = 0 ; j < x ; j ++ )
@@ -118,14 +118,18 @@ int main(int argc, char *argv[])
118
118
{
119
119
if (!is_close (float16to32 (CC [j << l ]), truncate_float32_to_bfloat16 (C [j << l ]), 0.01 , 0.001 ))
120
120
{
121
- printf ("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n" ,
121
+ #ifdef DEBUG
122
+ printf ("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n" ,
122
123
transA , alpha , beta , i , j , k , float16to32 (CC [j << l ]), truncate_float32_to_bfloat16 (C [j << l ]));
124
+ #endif
123
125
ret ++ ;
124
126
}
125
127
if (!is_close (float16to32 (CC [j << l ]), truncate_float32_to_bfloat16 (DD [j ]), 0.001 , 0.0001 ))
126
128
{
127
- printf ("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n" ,
129
+ #ifdef DEBUG
130
+ printf ("Mismatch at trans=%c, alpha=%.2f, beta=%.2f, i=%d, j=%d, k=%ld: CC=%.6f, C=%.6f\n" ,
128
131
transA , alpha , beta , i , j , k , float16to32 (CC [j << l ]), truncate_float32_to_bfloat16 (DD [j ]));
132
+ #endif
129
133
ret ++ ;
130
134
}
131
135
}
0 commit comments