Skip to content

Commit c61ea87

Browse files
committed
Add support for 8-bit floats
1 parent 46d598c commit c61ea87

File tree

4 files changed

+102
-6
lines changed

4 files changed

+102
-6
lines changed

include/kernel_float/fp8.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef KERNEL_FLOAT_FP8_H
2+
#define KERNEL_FLOAT_FP8_H
3+
4+
#include "macros.h"
5+
6+
#if KERNEL_FLOAT_FP8_AVAILABLE
7+
#include <cuda_fp8.h>
8+
9+
#include "vector.h"
10+
11+
namespace kernel_float {
12+
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_fp8_e4m3)
13+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_fp8_e4m3)
14+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_fp8_e4m3)
15+
16+
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_fp8_e5m2)
17+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_fp8_e5m2)
18+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_fp8_e5m2)
19+
} // namespace kernel_float
20+
21+
#if KERNEL_FLOAT_FP16_AVAILABLE
22+
#include "fp16.h"
23+
24+
namespace kernel_float {
25+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
26+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
27+
} // namespace kernel_float
28+
#endif // KERNEL_FLOAT_FP16_AVAILABLE
29+
30+
#if KERNEL_FLOAT_BF16_AVAILABLE
31+
#include "bf16.h"
32+
33+
namespace kernel_float {
34+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
35+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
36+
} // namespace kernel_float
37+
#endif // KERNEL_FLOAT_BF16_AVAILABLE
38+
39+
#endif // KERNEL_FLOAT_FP8_AVAILABLE
40+
#endif // KERNEL_FLOAT_FP8_H

include/kernel_float/prelude.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "bf16.h"
55
#include "constant.h"
66
#include "fp16.h"
7+
#include "fp8.h"
78
#include "vector.h"
89

910
namespace kernel_float {
@@ -66,8 +67,14 @@ KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
6667
#endif
6768

6869
#if KERNEL_FLOAT_BF16_AVAILABLE
69-
KERNEL_FLOAT_TYPE_ALIAS(bfloat16, __nv_bfloat16)
70-
KERNEL_FLOAT_TYPE_ALIAS(bf16, __nv_bfloat16)
70+
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16)
71+
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16)
72+
#endif
73+
74+
#if KERNEL_FLOAT_BF8_AVAILABLE
75+
KERNEL_FLOAT_TYPE_ALIAS(float8x, __nv_fp8_e4m3)
76+
KERNEL_FLOAT_TYPE_ALIAS(float8_e4m3x, __nv_fp8_e4m3)
77+
KERNEL_FLOAT_TYPE_ALIAS(float8_e5m2x, __nv_fp8_e5m2)
7178
#endif
7279

7380
template<size_t N>

single_include/kernel_float.h

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
//================================================================================
1818
// this file has been auto-generated, do not modify its contents!
19-
// date: 2023-09-21 10:00:11.122069
20-
// git hash: 227f987d3fc10499e680bb68f00e1c579afeda97
19+
// date: 2023-09-28 09:58:58.074478
20+
// git hash: 46d598cbca2b9e15abe91848fdcb417d69f0820a
2121
//================================================================================
2222

2323
#ifndef KERNEL_FLOAT_MACROS_H
@@ -3785,6 +3785,46 @@ struct promote_type<__half, __nv_bfloat16> {
37853785
#endif
37863786

37873787
#endif //KERNEL_FLOAT_BF16_H
3788+
#ifndef KERNEL_FLOAT_FP8_H
3789+
#define KERNEL_FLOAT_FP8_H
3790+
3791+
3792+
3793+
#if KERNEL_FLOAT_FP8_AVAILABLE
3794+
#include <cuda_fp8.h>
3795+
3796+
3797+
3798+
namespace kernel_float {
3799+
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_fp8_e4m3)
3800+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_fp8_e4m3)
3801+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_fp8_e4m3)
3802+
3803+
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_fp8_e5m2)
3804+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_fp8_e5m2)
3805+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_fp8_e5m2)
3806+
} // namespace kernel_float
3807+
3808+
#if KERNEL_FLOAT_FP16_AVAILABLE
3809+
3810+
3811+
namespace kernel_float {
3812+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
3813+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
3814+
} // namespace kernel_float
3815+
#endif // KERNEL_FLOAT_FP16_AVAILABLE
3816+
3817+
#if KERNEL_FLOAT_BF16_AVAILABLE
3818+
3819+
3820+
namespace kernel_float {
3821+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
3822+
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
3823+
} // namespace kernel_float
3824+
#endif // KERNEL_FLOAT_BF16_AVAILABLE
3825+
3826+
#endif // KERNEL_FLOAT_FP8_AVAILABLE
3827+
#endif // KERNEL_FLOAT_FP8_H
37883828
#ifndef KERNEL_FLOAT_PRELUDE_H
37893829
#define KERNEL_FLOAT_PRELUDE_H
37903830

@@ -3793,6 +3833,7 @@ struct promote_type<__half, __nv_bfloat16> {
37933833

37943834

37953835

3836+
37963837
namespace kernel_float {
37973838
namespace prelude {
37983839
namespace kf = ::kernel_float;
@@ -3853,8 +3894,14 @@ KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
38533894
#endif
38543895

38553896
#if KERNEL_FLOAT_BF16_AVAILABLE
3856-
KERNEL_FLOAT_TYPE_ALIAS(bfloat16, __nv_bfloat16)
3857-
KERNEL_FLOAT_TYPE_ALIAS(bf16, __nv_bfloat16)
3897+
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16)
3898+
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16)
3899+
#endif
3900+
3901+
#if KERNEL_FLOAT_BF8_AVAILABLE
3902+
KERNEL_FLOAT_TYPE_ALIAS(float8x, __nv_fp8_e4m3)
3903+
KERNEL_FLOAT_TYPE_ALIAS(float8_e4m3x, __nv_fp8_e4m3)
3904+
KERNEL_FLOAT_TYPE_ALIAS(float8_e5m2x, __nv_fp8_e5m2)
38583905
#endif
38593906

38603907
template<size_t N>

tests/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <cstdint>
77

88
#include "catch2/catch_all.hpp"
9+
10+
#define KERNEL_FLOAT_FP8_AVAILABLE (1)
911
#include "kernel_float.h"
1012

1113
namespace kf = kernel_float;

0 commit comments

Comments
 (0)