@@ -49,66 +49,55 @@ struct zip_bfloat16x2 {
 
 template<typename F, size_t N>
 struct apply_impl<F, N, __nv_bfloat16, __nv_bfloat16> {
-    KERNEL_FLOAT_INLINE static vector_storage<__nv_bfloat16, N>
-    call(F fun, const vector_storage<__nv_bfloat16, N>& input) {
-        vector_storage<__nv_bfloat16, N> result;
-
+    KERNEL_FLOAT_INLINE static void call(F fun, __nv_bfloat16* result, const __nv_bfloat16* input) {
 #pragma unroll
-        for (size_t i = 0; i + 2 <= N; i += 2) {
-            __nv_bfloat162 a = {input.data()[i], input.data()[i + 1]};
+        for (size_t i = 0; 2 * i + 1 < N; i++) {
+            __nv_bfloat162 a = {input[2 * i], input[2 * i + 1]};
             __nv_bfloat162 b = map_bfloat16x2<F>::call(fun, a);
-            result.data()[i + 0] = b.x;
-            result.data()[i + 1] = b.y;
+            result[2 * i + 0] = b.x;
+            result[2 * i + 1] = b.y;
         }
 
         if (N % 2 != 0) {
-            result.data()[N - 1] = fun(input.data()[N - 1]);
+            result[N - 1] = fun(input[N - 1]);
         }
-
-        return result;
     }
 };
 
 template<typename F, size_t N>
 struct apply_impl<F, N, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> {
-    KERNEL_FLOAT_INLINE static vector_storage<__nv_bfloat16, N> call(
-        F fun,
-        const vector_storage<__nv_bfloat16, N>& left,
-        const vector_storage<__nv_bfloat16, N>& right) {
-        vector_storage<__nv_bfloat16, N> result;
+    KERNEL_FLOAT_INLINE static void
+    call(F fun, __nv_bfloat16* result, const __nv_bfloat16* left, const __nv_bfloat16* right) {
 #pragma unroll
-        for (size_t i = 0; i + 2 <= N; i += 2) {
-            __nv_bfloat162 a = {left.data()[i], left.data()[i + 1]};
-            __nv_bfloat162 b = {right.data()[i], right.data()[i + 1]};
+        for (size_t i = 0; 2 * i + 1 < N; i++) {
+            __nv_bfloat162 a = {left[2 * i], left[2 * i + 1]};
+            __nv_bfloat162 b = {right[2 * i], right[2 * i + 1]};
             __nv_bfloat162 c = zip_bfloat16x2<F>::call(fun, a, b);
-            result.data()[i + 0] = c.x;
-            result.data()[i + 1] = c.y;
+            result[2 * i + 0] = c.x;
+            result[2 * i + 1] = c.y;
         }
 
         if (N % 2 != 0) {
-            result.data()[N - 1] = fun(left.data()[N - 1], right.data()[N - 1]);
+            result[N - 1] = fun(left[N - 1], right[N - 1]);
         }
-
-        return result;
     }
 };
9585
9686template <typename F, size_t N>
9787struct reduce_impl <F, N, __nv_bfloat16, enable_if_t <(N >= 2 )>> {
98- KERNEL_FLOAT_INLINE static __nv_bfloat16
99- call (F fun, const vector_storage<__nv_bfloat16, N>& input) {
100- __nv_bfloat162 accum = {input.data ()[0 ], input.data ()[1 ]};
88+ KERNEL_FLOAT_INLINE static __nv_bfloat16 call (F fun, const __nv_bfloat16* input) {
89+ __nv_bfloat162 accum = {input[0 ], input[1 ]};
10190
10291#pragma unroll
103- for (size_t i = 2 ; i + 2 <= N; i += 2 ) {
104- __nv_bfloat162 a = {input. data ()[ i], input. data ()[ i + 1 ]};
+        for (size_t i = 1; 2 * i + 1 < N; i++) {
+            __nv_bfloat162 a = {input[2 * i], input[2 * i + 1]};
             accum = zip_bfloat16x2<F>::call(fun, accum, a);
         }
 
         __nv_bfloat16 result = fun(accum.x, accum.y);
 
         if (N % 2 != 0) {
-            result = fun(result, input.data()[N - 1]);
+            result = fun(result, input[N - 1]);
         }
 
         return result;
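
For reference, the rewritten loops visit element pairs (2*i, 2*i + 1) while 2*i + 1 < N, so an odd N leaves exactly one trailing element for the scalar fallback. Note that the reduction loop must start at i = 1, since accum is already seeded with the first pair. A minimal stand-alone sketch of the indexing scheme (not part of the commit; plain float stands in for __nv_bfloat16):

// Sketch only: mirrors the pairing and odd-tail handling of the new loops.
#include <cstddef>
#include <cstdio>

template<std::size_t N, typename F>
void apply_pairwise(F fun, float* result, const float* input) {
    for (std::size_t i = 0; 2 * i + 1 < N; i++) {
        result[2 * i + 0] = fun(input[2 * i + 0]);  // pairs (0,1), (2,3), ...
        result[2 * i + 1] = fun(input[2 * i + 1]);
    }
    if (N % 2 != 0) {
        result[N - 1] = fun(input[N - 1]);  // lone trailing element when N is odd
    }
}

int main() {
    float in[5] = {1, 2, 3, 4, 5};
    float out[5];
    apply_pairwise<5>([](float x) { return 2.0f * x; }, out, in);
    for (float v : out) {
        std::printf("%g ", v);  // prints: 2 4 6 8 10
    }
    std::printf("\n");
}
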
@@ -126,6 +115,7 @@ struct reduce_impl<F, N, __nv_bfloat16, enable_if_t<(N >= 2)>> {
     }; \
     }
 
+// These operations are not implemented in bfloat16 precision, so they are forwarded to single precision
 KERNEL_FLOAT_BF16_UNARY_FORWARD(tan)
 KERNEL_FLOAT_BF16_UNARY_FORWARD(asin)
 KERNEL_FLOAT_BF16_UNARY_FORWARD(acos)
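
The body of KERNEL_FLOAT_BF16_UNARY_FORWARD lies outside this hunk, but the forwarding described by the new comment amounts to widening the bfloat16 value to float, calling the single-precision math function, and narrowing the result back. A hedged sketch of that pattern for tan (the function name is illustrative, not the macro's actual expansion):

// Assumed shape of the forwarding path; bf16_tan_forwarded is a made-up name.
#include <cuda_bf16.h>

__device__ __nv_bfloat16 bf16_tan_forwarded(__nv_bfloat16 x) {
    return __float2bfloat16(tanf(__bfloat162float(x)));  // widen, compute in float, narrow
}
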
@@ -243,32 +233,22 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2)
 KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), double(__bfloat162float(input)));
 KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input));
 
+// clang-format off
 // there are no official char casts. Instead, cast to int and then to char
 KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
-KERNEL_FLOAT_BF16_CAST(
-    signed char,
-    __int2bfloat16_rn(input),
-    (signed char)__bfloat162int_rz(input));
-KERNEL_FLOAT_BF16_CAST(
-    unsigned char,
-    __int2bfloat16_rn(input),
-    (unsigned char)__bfloat162int_rz(input));
+KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input));
+KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input));
 
 KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input));
 KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input));
-KERNEL_FLOAT_BF16_CAST(
-    signed long,
-    __ll2bfloat16_rn(input),
-    (signed long)(__bfloat162ll_rz(input)));
+KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input)));
 KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input));
 
 KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input));
 KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input));
-KERNEL_FLOAT_BF16_CAST(
-    unsigned long,
-    __ull2bfloat16_rn(input),
-    (unsigned long)(__bfloat162ull_rz(input)));
+KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input)));
 KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
+// clang-format on
 
 using bfloat16 = __nv_bfloat16;
 // KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16)
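
As the comment in the cast block notes, cuda_bf16.h has no direct char conversions, so the char casts round-trip through the int intrinsics. A hedged stand-alone sketch of that round trip (helper names are illustrative, not library API):

// Assumed illustration of the "cast to int, then to char" workaround.
#include <cuda_bf16.h>

__device__ __nv_bfloat16 char_to_bf16(char c) {
    return __int2bfloat16_rn(static_cast<int>(c));   // widen to int, round-to-nearest into bf16
}

__device__ char bf16_to_char(__nv_bfloat16 x) {
    return static_cast<char>(__bfloat162int_rz(x));  // truncate toward zero, then narrow to char
}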