@@ -46,6 +46,27 @@ typedef union
46
46
} bits ;
47
47
} bfloat16_bits ;
48
48
49
+ typedef union /* Overlays a float with a bit-field view of the same storage. */
50
+ {
51
+ float v ; /* the value accessed as a float */
52
+ struct
53
+ {
54
+ uint32_t m :23 ; /* 23-bit field, presumably the mantissa; NOTE(review): bit-field allocation order is implementation-defined -- confirm for the target ABI */
55
+ uint32_t e :8 ; /* 8-bit field, presumably the biased exponent */
56
+ uint32_t s :1 ; /* 1-bit field, presumably the sign */
57
+ } bits ;
58
+ } float32_bits ;
59
+
60
+ float /* Builds a float from a bfloat16_bits value by filling the float32_bits fields one by one. */
61
+ float16to32 (bfloat16_bits f16 )
62
+ {
63
+ float32_bits f32 ; /* left uninitialized: the three assignments below write all 1 + 8 + 23 = 32 declared bits */
64
+ f32 .bits .s = f16 .bits .s ; /* s field copied unchanged */
65
+ f32 .bits .e = f16 .bits .e ; /* e field copied unchanged */
66
+ f32 .bits .m = (uint32_t ) f16 .bits .m << 16 ; /* source m moved up 16 bits, low 16 bits become zero; NOTE(review): assumes f16.bits.m is 7 bits wide (declaration not visible here) -- confirm */
67
+ return f32 .v ; /* read the assembled bits back through the float member of the union */
68
+ }
69
+
49
70
int
50
71
main (int argc , char * argv [])
51
72
{
@@ -56,8 +77,6 @@ main (int argc, char *argv[])
56
77
int loop = 100 ;
57
78
char transA = 'N' , transB = 'N' ;
58
79
float alpha = 1.0 , beta = 0.0 ;
59
- char transa = 'N' ;
60
- char transb = 'N' ;
61
80
62
81
for (x = 0 ; x <= loop ; x ++ )
63
82
{
@@ -66,30 +85,45 @@ main (int argc, char *argv[])
66
85
float B [k * n ];
67
86
float C [m * n ];
68
87
bfloat16_bits AA [m * k ], BB [k * n ];
69
- float CC [m * n ];
88
+ float DD [ m * n ], CC [m * n ];
70
89
71
90
for (j = 0 ; j < m ; j ++ )
72
91
{
73
92
for (i = 0 ; i < m ; i ++ )
74
93
{
75
- A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
76
- B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
94
+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
95
+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
77
96
C [j * k + i ] = 0 ;
78
97
AA [j * k + i ].v = * (uint32_t * ) & A [j * k + i ] >> 16 ;
79
98
BB [j * k + i ].v = * (uint32_t * ) & B [j * k + i ] >> 16 ;
80
99
CC [j * k + i ] = 0 ;
100
+ DD [j * k + i ] = 0 ;
81
101
}
82
102
}
83
103
SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
84
- & m , B , & k , & beta , C , & m );
104
+ & m , B , & k , & beta , C , & m );
85
105
SHGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
86
- & m , BB , & k , & beta , CC , & m );
87
-
106
+ & m , BB , & k , & beta , CC , & m );
88
107
for (i = 0 ; i < n ; i ++ )
89
- for (j = 0 ; j < m ; j ++ )
90
- for (l = 0 ; l < k ; l ++ )
91
- if (fabs (CC [i * m + j ]- C [i * m + j ]) > 1.0 )
92
- ret ++ ;
108
+ for (j = 0 ; j < m ; j ++ )
109
+ for (l = 0 ; l < k ; l ++ )
110
+ if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
111
+ ret ++ ;
112
+ if (transA == 'N' && transB == 'N' )
113
+ {
114
+ for (i = 0 ; i < n ; i ++ )
115
+ for (j = 0 ; j < m ; j ++ )
116
+ for (l = 0 ; l < k ; l ++ )
117
+ {
118
+ DD [i * m + j ] +=
119
+ float16to32 (AA [l * m + j ]) * float16to32 (BB [l + k * i ]);
120
+ }
121
+ for (i = 0 ; i < n ; i ++ )
122
+ for (j = 0 ; j < m ; j ++ )
123
+ for (l = 0 ; l < k ; l ++ )
124
+ if (CC [i * m + j ] != DD [i * m + j ])
125
+ ret ++ ;
126
+ }
93
127
}
94
128
if (ret != 0 )
95
129
fprintf (stderr , "FATAL ERROR SHGEMM - Return code: %d\n" , ret );
0 commit comments