@@ -85,7 +85,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57)
8585KERNEL_FLOAT_DEFINE_POLY (asin_poly, 5 , 0.009796 , -0.03772 , 0.0857 , -0.2142 , 1.57 )
8686
8787#if KERNEL_FLOAT_FP16_AVAILABLE
88- KERNEL_FLOAT_DEVICE __half2 flipsign (__half2 input, __half2 sign) {
88+ KERNEL_FLOAT_DEVICE half2_t flipsign (half2_t input, half2_t sign) {
8989 // Flip signbit of input when sign<0
9090 uint32_t result;
9191
@@ -97,10 +97,10 @@ KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) {
9797 result = uint32_t (transmute<uint32_t >(sign) & 0x80008000 ) ^ transmute<uint32_t >(input);
9898#endif
9999
100- return transmute<__half2 >(result);
100+ return transmute<half2_t >(result);
101101}
102102
103- KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask (__half2 a, __half2 b) {
103+ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask (half2_t a, half2_t b) {
104104 uint32_t val;
105105#if KERNEL_FLOAT_IS_CUDA
106106 uint32_t ai = *(reinterpret_cast <const uint32_t *>(&a));
@@ -112,42 +112,42 @@ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) {
112112 return val;
113113}
114114
115- KERNEL_FLOAT_INLINE __half2 make_half2 (half x) {
115+ KERNEL_FLOAT_INLINE half2_t make_half2 (half x) {
116116 return {x, x};
117117}
118118
119- KERNEL_FLOAT_DEVICE __half2 normalize_trig_input (__half2 x) {
119+ KERNEL_FLOAT_DEVICE half2_t normalize_trig_input (half2_t x) {
120120 /* Using rint is too slow. Round using floating-point magic instead. */
121- // __half2 x = arg * make_half2(-0.15915494309);
121+ // half2_t x = arg * make_half2(-0.15915494309);
122122 // return __hfma2(arg, make_half2(0.15915494309), h2rint(x));
123123
124124 // 1/(2pi) = 0.15915494309189535
125125 static constexpr double ONE_OVER_TWOPI = 0.15915494309189535 ;
126126 static constexpr double OFFSET = -2042.0 ;
127127
128- __half2 ws = __hfma2 (x, make_half2 (-ONE_OVER_TWOPI), make_half2 (-OFFSET)) + make_half2 (OFFSET);
128+ half2_t ws = __hfma2 (x, make_half2 (-ONE_OVER_TWOPI), make_half2 (-OFFSET)) + make_half2 (OFFSET);
129129 return __hfma2 (x, make_half2 (ONE_OVER_TWOPI), ws);
130130}
131131
132132template <int Iter>
133- KERNEL_FLOAT_DEVICE __half2 cos (__half2 x) {
134- __half2 xf = normalize_trig_input (x);
133+ KERNEL_FLOAT_DEVICE half2_t cos (half2_t x) {
134+ half2_t xf = normalize_trig_input (x);
135135 return cos_poly<half, Iter + 1 >::call (__hmul2 (xf, xf));
136136}
137137
138138template <int Iter>
139- KERNEL_FLOAT_DEVICE __half2 sin (__half2 x) {
140- __half2 xf = normalize_trig_input (x);
139+ KERNEL_FLOAT_DEVICE half2_t sin (half2_t x) {
140+ half2_t xf = normalize_trig_input (x);
141141 return sin_poly<half, Iter>::call (__hmul2 (xf, xf)) * xf;
142142}
143143
144144template <int Iter>
145- KERNEL_FLOAT_DEVICE __half2 rcp (__half2 x) {
145+ KERNEL_FLOAT_DEVICE half2_t rcp (half2_t x) {
146146 // Flip bits
147147 uint32_t m = ~transmute<uint32_t >(x);
148148
149149 // Multiply by bias (add contant)
150- __half2 y = transmute<__half2 >(uint32_t (0x776d776d ) + m);
150+ half2_t y = transmute<half2_t >(uint32_t (0x776d776d ) + m);
151151
152152#pragma unroll
153153 for (int i = 0 ; i < Iter; i++) {
@@ -159,40 +159,40 @@ KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) {
159159}
160160
161161template <int Iter>
162- KERNEL_FLOAT_DEVICE __half2 rsqrt (__half2 x) {
162+ KERNEL_FLOAT_DEVICE half2_t rsqrt (half2_t x) {
163163 // Set top and bottom bits for both halfs, then shift by 1, then invert
164164 uint32_t r = ~((uint32_t (transmute<uint32_t >(x) >> 1 )) | ~uint32_t (0x3fff3fff ));
165165 // uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
166166
167167 // Add bias (0x199c)
168- __half2 y = transmute<__half2 >(uint32_t (r) + uint32_t (0x199c199c ));
168+ half2_t y = transmute<half2_t >(uint32_t (r) + uint32_t (0x199c199c ));
169169
170170 // Newton-Raphson iterations
171171#pragma unroll
172172 for (int i = 0 ; i < Iter; i++) {
173- __half2 half_x = make_half2 (-0.5 ) * x;
174- __half2 correction = __hfma2 (half_x, y * y, make_half2 (0.5 ));
173+ half2_t half_x = make_half2 (-0.5 ) * x;
174+ half2_t correction = __hfma2 (half_x, y * y, make_half2 (0.5 ));
175175 y = __hfma2 (correction, y, y); // y += y * correction
176176 }
177177
178178 return y;
179179}
180180
181181template <int Iter>
182- KERNEL_FLOAT_DEVICE __half2 sqrt (__half2 x) {
182+ KERNEL_FLOAT_DEVICE half2_t sqrt (half2_t x) {
183183 if (Iter == 1 ) {
184- __half2 y = rsqrt<0 >(x);
184+ half2_t y = rsqrt<0 >(x);
185185
186186 // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)`
187- __half2 xy = x * y;
187+ half2_t xy = x * y;
188188 return xy * __hfma2 (make_half2 (-0.5 ) * y, xy, make_half2 (1.5 ));
189189 }
190190
191191 return x * rsqrt<Iter>(x);
192192}
193193
194194template <int Iter>
195- KERNEL_FLOAT_DEVICE __half2 asin (__half2 x) {
195+ KERNEL_FLOAT_DEVICE half2_t asin (half2_t x) {
196196 static constexpr double HALF_PI = 1.57079632679 ;
197197 auto abs_x = __habs2 (x);
198198 auto v = asin_poly<half, Iter + 1 >::call (abs_x);
@@ -201,36 +201,36 @@ KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) {
201201}
202202
203203template <int Iter>
204- KERNEL_FLOAT_DEVICE __half2 acos (__half2 x) {
204+ KERNEL_FLOAT_DEVICE half2_t acos (half2_t x) {
205205 static constexpr double HALF_PI = 1.57079632679 ;
206206 return make_half2 (HALF_PI) - asin<Iter>(x);
207207}
208208
209209template <int Deg>
210- KERNEL_FLOAT_DEVICE __half2 exp (__half2 x) {
211- __half2 y;
210+ KERNEL_FLOAT_DEVICE half2_t exp (half2_t x) {
211+ half2_t y;
212212
213213 if (Deg == 0 ) {
214214 // Bring the value to range [32, 64]
215215 // 1.442 = 1/log(2)
216216 // 46.969 = 32.5/log(2)
217- __half2 m = __hfma2 (x, make_half2 (1.442 ), make_half2 (46.9375 ));
217+ half2_t m = __hfma2 (x, make_half2 (1.442 ), make_half2 (46.9375 ));
218218
219219 // Transmute to int, shift higher mantissa bits into exponent field.
220- y = transmute<__half2 >((transmute<uint32_t >(m) & 0x03ff03ff ) << 5 );
220+ y = transmute<half2_t >((transmute<uint32_t >(m) & 0x03ff03ff ) << 5 );
221221 } else {
222222 // Add a large number to round to an integer
223- __half2 v = __hfma2 (x, make_half2 (1.442 ), make_half2 (1231.0 ));
223+ half2_t v = __hfma2 (x, make_half2 (1.442 ), make_half2 (1231.0 ));
224224
225225 // The exponent is now in the lower 5 bits. Shift that into the exponent field.
226- __half2 exp = transmute<__half2 >((transmute<uint32_t >(v) & 0x001f001f ) << 10 );
226+ half2_t exp = transmute<half2_t >((transmute<uint32_t >(v) & 0x001f001f ) << 10 );
227227
228228 // The fractional part can be obtained from "1231-v".
229229 // 0.6934 = log(2)
230- __half2 frac = __hfma2 (make_half2 (1231.0 ) - v, make_half2 (0.6934 ), x);
230+ half2_t frac = __hfma2 (make_half2 (1231.0 ) - v, make_half2 (0.6934 ), x);
231231
232232 // This is the Taylor expansion of "exp(x)-1" around 0
233- __half2 adjust;
233+ half2_t adjust;
234234 if (Deg == 1 ) {
235235 adjust = frac;
236236 } else if (Deg == 2 ) {
@@ -250,21 +250,21 @@ KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) {
250250
251251 // Values below -10.39 (= -15*log(2)) become zero
252252 uint32_t zero_mask = half2_gt_mask (x, make_half2 (-10.390625 ));
253- return transmute<__half2 >(zero_mask & transmute<uint32_t >(y));
253+ return transmute<half2_t >(zero_mask & transmute<uint32_t >(y));
254254}
255255
256256template <int = 0 >
257- KERNEL_FLOAT_DEVICE __half2 log (__half2 arg) {
257+ KERNEL_FLOAT_DEVICE half2_t log (half2_t arg) {
258258 // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0)
259259 uint32_t bits = bitwise_if_else (0x03ff03ff , transmute<uint32_t >(arg) >> 5 , 0x50005000 );
260260
261261 // 0.6934 = log(2)
262262 // 32.53 = 46.969*log(2)
263- return __hfma2 (transmute<__half2 >(bits), make_half2 (0.6934 ), make_half2 (-32.53125 ));
263+ return __hfma2 (transmute<half2_t >(bits), make_half2 (0.6934 ), make_half2 (-32.53125 ));
264264}
265265
266266template <int Deg>
267- KERNEL_FLOAT_DEVICE __half2 tanh (__half2 x) {
267+ KERNEL_FLOAT_DEVICE half2_t tanh (half2_t x) {
268268 if (Deg == 0 ) {
269269 return x * rcp<0 >(make_half2 (0.2869 ) + __habs2 (x));
270270 } else {
@@ -278,39 +278,39 @@ KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) {
278278#endif // KERNEL_FLOAT_FP16_AVAILABLE
279279
280280#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
281- KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162 (__bfloat16 x) {
281+ KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162 (bfloat16_t x) {
282282 return {x, x};
283283}
284284
285- KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162 (double x) {
285+ KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162 (double x) {
286286 return {__double2bfloat16 (x), __double2bfloat16 (x)};
287287}
288288
289- KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input (__nv_bfloat162 x) {
289+ KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input (bfloat16x2_t x) {
290290 static constexpr double ONE_OVER_TWOPI = 0.15915494309189535 ;
291291 static constexpr double OFFSET = -2042.0 ;
292292
293- __bfloat162 ws = __hadd2 (
293+ bfloat16x2_t ws = __hadd2 (
294294 __hfma2 (x, make_bfloat162 (-ONE_OVER_TWOPI), make_bfloat162 (-OFFSET)),
295295 make_bfloat162 (OFFSET));
296296 return __hfma2 (x, make_bfloat162 (ONE_OVER_TWOPI), ws);
297297}
298298
299299template <int Iter>
300- KERNEL_FLOAT_DEVICE __bfloat162 cos (__bfloat162 x) {
301- __bfloat162 xf = normalize_trig_input (x);
300+ KERNEL_FLOAT_DEVICE bfloat16x2_t cos (bfloat16x2_t x) {
301+ bfloat16x2_t xf = normalize_trig_input (x);
302302 return cos_poly<__bfloat16, Iter + 1 >::call (__hmul2 (xf, xf));
303303}
304304
305305template <int Iter>
306- KERNEL_FLOAT_DEVICE __bfloat162 sin (__bfloat162 x) {
307- __bfloat162 xf = normalize_trig_input (x);
306+ KERNEL_FLOAT_DEVICE bfloat16x2_t sin (bfloat16x2_t x) {
307+ bfloat16x2_t xf = normalize_trig_input (x);
308308 return __hmul2 (sin_poly<__bfloat16, Iter>::call (__hmul2 (xf, xf)), xf);
309309}
310310
311311template <int Iter>
312- KERNEL_FLOAT_DEVICE __bfloat162 rcp (__bfloat162 x) {
313- __bfloat162 y = transmute<__bfloat162 >(uint32_t (0x7ef07ef0 ) + ~transmute<uint32_t >(x));
312+ KERNEL_FLOAT_DEVICE bfloat16x2_t rcp (bfloat16x2_t x) {
313+ bfloat16x2_t y = transmute<bfloat16x2_t >(uint32_t (0x7ef07ef0 ) + ~transmute<uint32_t >(x));
314314
315315#pragma unroll
316316 for (int i = 0 ; i < Iter; i++) {
@@ -321,36 +321,36 @@ KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) {
321321}
322322
323323template <int Iter>
324- KERNEL_FLOAT_DEVICE __bfloat162 rsqrt (__bfloat162 x) {
324+ KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt (bfloat16x2_t x) {
325325 // Set top and bottom bits for both halfs, then shift by 1, then invert
326326 uint32_t r = ~((uint32_t (transmute<uint32_t >(x) >> 1 )) | ~uint32_t (0x3fff3fff ));
327327
328328 // Add bias (0x1f36)
329- __bfloat162 y = transmute<__bfloat162 >(uint32_t (r) + uint32_t (0x1f361f36 ));
329+ bfloat16x2_t y = transmute<bfloat16x2_t >(uint32_t (r) + uint32_t (0x1f361f36 ));
330330
331331 // Newton-Raphson iterations
332332#pragma unroll
333333 for (int i = 0 ; i < Iter; i++) {
334- __bfloat162 half_x = __hmul2 (make_bfloat162 (-0.5 ), x);
335- __bfloat162 correction = __hfma2 (half_x, __hmul2 (y, y), make_bfloat162 (0.5 ));
334+ bfloat16x2_t half_x = __hmul2 (make_bfloat162 (-0.5 ), x);
335+ bfloat16x2_t correction = __hfma2 (half_x, __hmul2 (y, y), make_bfloat162 (0.5 ));
336336 y = __hfma2 (correction, y, y); // y += y * correction
337337 }
338338
339339 return y;
340340}
341341
342342template <int Iter>
343- KERNEL_FLOAT_DEVICE __bfloat162 sqrt (__bfloat162 x) {
343+ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt (bfloat16x2_t x) {
344344 return __hmul2 (x, rsqrt<Iter>(x));
345345}
346346
347347template <int = 0 >
348- KERNEL_FLOAT_DEVICE __bfloat162 exp (__bfloat162 arg) {
348+ KERNEL_FLOAT_DEVICE bfloat16x2_t exp (bfloat16x2_t arg) {
349349 static constexpr float SCALE = 1 .44272065994f / 256 .0f ;
350350 static constexpr float OFFSET = 382.4958400542335 ;
351351
352- auto a = fmaf (__bfloat162float (arg.x ), SCALE, OFFSET);
353- auto b = fmaf (__bfloat162float (arg.y ), SCALE, OFFSET);
352+ auto a = fmaf (bfloat16x2_tfloat (arg.x ), SCALE, OFFSET);
353+ auto b = fmaf (bfloat16x2_tfloat (arg.y ), SCALE, OFFSET);
354354
355355 return {
356356 transmute<__bfloat16>(uint16_t (transmute<uint32_t >(a))),
@@ -362,17 +362,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) {
362362#define KERNEL_FLOAT_DEFINE_APPROX_FUN (FULL_NAME, FUN, DEG ) \
363363 namespace detail { \
364364 template <int Degree> \
365- struct apply_impl <approx_level_policy<Degree>, ops::FUN<__half >, 2 , __half, __half > { \
365+ struct apply_impl <approx_level_policy<Degree>, ops::FUN<half_t >, 2 , half_t , half_t > { \
366366 KERNEL_FLOAT_INLINE static void \
367- call (ops::FUN<__half > fun, __half * output, const __half * input) { \
368- __half2 res = approx::FUN<Degree>(__half2 {input[0 ], input[1 ]}); \
367+ call (ops::FUN<half_t > fun, half_t * output, const half_t * input) { \
368+ half2_t res = approx::FUN<Degree>(half2_t {input[0 ], input[1 ]}); \
369369 output[0 ] = res.x ; \
370370 output[1 ] = res.y ; \
371371 } \
372372 }; \
373373 template <> \
374- struct apply_impl <approx_policy, ops::FUN<__half >, 2 , __half, __half >: \
375- apply_impl<approx_level_policy<DEG>, ops::FUN<__half >, 2 , __half, __half > {}; \
374+ struct apply_impl <approx_policy, ops::FUN<half_t >, 2 , half_t , half_t >: \
375+ apply_impl<approx_level_policy<DEG>, ops::FUN<half_t >, 2 , half_t , half_t > {}; \
376376 } \
377377 \
378378 template <int Level = -1 , typename V> \
0 commit comments