@@ -46,6 +46,27 @@ typedef union
46
46
} bits ;
47
47
} bfloat16_bits ;
48
48
49
+ typedef union
50
+ {
51
+ float v ;
52
+ struct
53
+ {
54
+ uint32_t m :23 ;
55
+ uint32_t e :8 ;
56
+ uint32_t s :1 ;
57
+ } bits ;
58
+ } float32_bits ;
59
+
60
+ float
61
+ float16to32 (bfloat16_bits f16 )
62
+ {
63
+ float32_bits f32 ;
64
+ f32 .bits .s = f16 .bits .s ;
65
+ f32 .bits .e = f16 .bits .e ;
66
+ f32 .bits .m = (uint32_t ) f16 .bits .m << 16 ;
67
+ return f32 .v ;
68
+ }
69
+
49
70
int
50
71
main (int argc , char * argv [])
51
72
{
@@ -55,8 +76,6 @@ main (int argc, char *argv[])
55
76
int loop = 100 ;
56
77
char transA = 'N' , transB = 'N' ;
57
78
float alpha = 1.0 , beta = 0.0 ;
58
- char transa = 'N' ;
59
- char transb = 'N' ;
60
79
61
80
for (int x = 0 ; x <= loop ; x ++ )
62
81
{
@@ -65,30 +84,45 @@ main (int argc, char *argv[])
65
84
float B [k * n ];
66
85
float C [m * n ];
67
86
bfloat16_bits AA [m * k ], BB [k * n ];
68
- float CC [m * n ];
87
+ float DD [ m * n ], CC [m * n ];
69
88
70
89
for (int j = 0 ; j < m ; j ++ )
71
90
{
72
91
for (int i = 0 ; i < m ; i ++ )
73
92
{
74
- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
75
- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
93
+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
94
+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
76
95
C [j * k + i ] = 0 ;
77
96
AA [j * k + i ].v = * (uint32_t * ) & A [j * k + i ] >> 16 ;
78
97
BB [j * k + i ].v = * (uint32_t * ) & B [j * k + i ] >> 16 ;
79
98
CC [j * k + i ] = 0 ;
99
+ DD [j * k + i ] = 0 ;
80
100
}
81
101
}
82
102
SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
83
- & m , B , & k , & beta , C , & m );
103
+ & m , B , & k , & beta , C , & m );
84
104
SHGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
85
- & m , BB , & k , & beta , CC , & m );
86
-
105
+ & m , BB , & k , & beta , CC , & m );
87
106
for (i = 0 ; i < n ; i ++ )
88
- for (j = 0 ; j < m ; j ++ )
89
- for (l = 0 ; l < k ; l ++ )
90
- if (fabs (CC [i * m + j ]- C [i * m + j ]) > 1.0 )
91
- ret ++ ;
107
+ for (j = 0 ; j < m ; j ++ )
108
+ for (l = 0 ; l < k ; l ++ )
109
+ if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
110
+ ret ++ ;
111
+ if (transA == 'N' && transB == 'N' )
112
+ {
113
+ for (i = 0 ; i < n ; i ++ )
114
+ for (j = 0 ; j < m ; j ++ )
115
+ for (l = 0 ; l < k ; l ++ )
116
+ {
117
+ DD [i * m + j ] +=
118
+ float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
119
+ }
120
+ for (i = 0 ; i < n ; i ++ )
121
+ for (j = 0 ; j < m ; j ++ )
122
+ for (l = 0 ; l < k ; l ++ )
123
+ if (CC [i * m + j ] != DD [i * m + j ])
124
+ ret ++ ;
125
+ }
92
126
}
93
127
if (ret != 0 )
94
128
fprintf (stderr , "FATAL ERROR SHGEMM - Return code: %d\n" , ret );
0 commit comments