1616
1717// ================================================================================
1818// this file has been auto-generated, do not modify its contents!
19- // date: 2024-03-18 16:06:55.100306
20- // git hash: 06e08f55399e148d96070afd0ac36dd414045f04
19+ // date: 2024-04-22 13:28:09.684538
20+ // git hash: fd4eadfbb0c8597276a6c12f972038cd1baff985
2121// ================================================================================
2222
2323#ifndef KERNEL_FLOAT_MACROS_H
@@ -2705,7 +2705,7 @@ struct vector_ref<T, N, const U, Align> {
27052705
27062706#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP (OP, OP_ASSIGN ) \
27072707 template <typename T, size_t N, typename U, size_t Align, typename V> \
2708- KERNEL_FLOAT_INLINE vector_ref<T, N> operator OP_ASSIGN ( \
2708+ KERNEL_FLOAT_INLINE vector_ref<T, N, U, Align > operator OP_ASSIGN ( \
27092709 vector_ref<T, N, U, Align> ptr, \
27102710 const V& value) { \
27112711 ptr.write (ptr.read () OP value); \
@@ -3379,6 +3379,7 @@ namespace kernel_float {
33793379 */
33803380template <typename T, typename E, class S >
33813381struct vector : public S {
3382+ using self_type = vector<T, E, S>;
33823383 using value_type = T;
33833384 using extent_type = E;
33843385 using storage_type = S;
@@ -3577,8 +3578,8 @@ struct vector: public S {
35773578 * vec<float, 4> vec2 = select(input, indices); // [0, 40, 40, 20]
35783579 * ```
35793580 */
3580- template <typename V, typename ... Is>
3581- KERNEL_FLOAT_INLINE select_type<V , Is...> select (const Is&... indices) {
3581+ template <typename ... Is>
3582+ KERNEL_FLOAT_INLINE select_type<self_type , Is...> select (const Is&... indices) {
35823583 return kernel_float::select (*this , indices...);
35833584 }
35843585
@@ -4255,6 +4256,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
42554256 static constexpr bool value = true ;
42564257};
42574258} // namespace detail
4259+ } // namespace kernel_float
42584260
42594261#define KERNEL_FLOAT_FP8_CAST (T ) \
42604262 namespace ops { \
@@ -4287,6 +4289,29 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
42874289 }; \
42884290 }
42894291
4292+ #define KERNEL_FLOAT_FP8_CAST2 (T, FP8_TY, FP8_INTERP ) \
4293+ namespace detail { \
4294+ template <> \
4295+ struct apply_impl <ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
4296+ KERNEL_FLOAT_INLINE static void call (ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
4297+ __half2_raw x; \
4298+ memcpy (&x, v, 2 * sizeof (T)); \
4299+ __nv_fp8x2_storage_t y = __nv_cvt_halfraw2_to_fp8x2 (x, __NV_NOSAT, FP8_INTERP); \
4300+ memcpy (result, &y, 2 * sizeof (FP8_TY)); \
4301+ } \
4302+ }; \
4303+ template <> \
4304+ struct apply_impl <ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
4305+ KERNEL_FLOAT_INLINE static void call (ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
4306+ __nv_fp8x2_storage_t x; \
4307+ memcpy (&x, v, 2 * sizeof (FP8_TY)); \
4308+ __half2_raw y = __nv_cvt_fp8x2_to_halfraw2 (x, FP8_INTERP); \
4309+ memcpy (result, &y, 2 * sizeof (T)); \
4310+ } \
4311+ }; \
4312+ }
4313+
4314+ namespace kernel_float {
42904315KERNEL_FLOAT_FP8_CAST (double )
42914316} // namespace kernel_float
42924317
@@ -4297,6 +4322,10 @@ namespace kernel_float {
42974322KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e4m3)
42984323KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half, __nv_fp8_e5m2)
42994324KERNEL_FLOAT_FP8_CAST (__half)
4325+
4326+ KERNEL_FLOAT_FP8_CAST2 (__half, __nv_fp8_e4m3, __NV_E4M3)
4327+ KERNEL_FLOAT_FP8_CAST2 (__half, __nv_fp8_e5m2, __NV_E5M2)
4328+
43004329} // namespace kernel_float
43014330#endif // KERNEL_FLOAT_FP16_AVAILABLE
43024331
@@ -4307,6 +4336,9 @@ namespace kernel_float {
43074336KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e4m3)
43084337KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16, __nv_fp8_e5m2)
43094338KERNEL_FLOAT_FP8_CAST (__nv_bfloat16)
4339+
4340+ KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
4341+ KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
43104342} // namespace kernel_float
43114343#endif // KERNEL_FLOAT_BF16_AVAILABLE
43124344
0 commit comments