
Commit 0c09cb2

Support bf16 in blackwell cutlass decode attention kernel
Summary:
1. Reduce pipeline stages to avoid exceeding the smem limit.
2. Add a static_assert so that an smem capacity violation is raised at compile time rather than at runtime.
3. Select the TMEM intrinsics based on sizeof(Element).
4. Update the unit test to include bf16.
5. Label decode kernel test names with their corresponding test parameters.

Differential Revision: D82991495
1 parent 5911680 commit 0c09cb2

File tree: 5 files changed (+70, -11 lines)


fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu

Lines changed: 7 additions & 2 deletions

@@ -262,10 +262,15 @@ struct GenRunner {
 };
 
 // Dispatch macros for different element types
-// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types.
 #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \
   [&] { \
-    if (DTYPE == at::kFloat8_e4m3fn) { \
+    if (DTYPE == at::kHalf) { \
+      using ELEMENT_TYPE = cutlass::half_t; \
+      return __VA_ARGS__(); \
+    } else if (DTYPE == at::kBFloat16) { \
+      using ELEMENT_TYPE = cutlass::bfloat16_t; \
+      return __VA_ARGS__(); \
+    } else if (DTYPE == at::kFloat8_e4m3fn) { \
       using ELEMENT_TYPE = cutlass::float_e4m3_t; \
       return __VA_ARGS__(); \
     } else { \
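
With half and bfloat16 handled, the macro follows the usual ATen-style dispatch pattern: the caller passes the alias name and a lambda body, and the body is textually substituted into the branch where that alias is defined. A minimal sketch of the call pattern, assuming the macro's expansion invokes the generated lambda and using a hypothetical run_gen_fmha<Element> launcher that is not part of this diff:

// Hedged sketch of how DISPATCH_ELEMENT_TYPE is typically consumed.
// `run_gen_fmha` is a hypothetical templated launcher, not part of this diff.
at::Tensor dispatch_example(const at::Tensor& q) {
  return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] {
    // Inside this body, `Element` is cutlass::half_t, cutlass::bfloat16_t,
    // or cutlass::float_e4m3_t, matching q's dtype; any other dtype falls
    // through to the macro's else branch.
    return run_gen_fmha<Element>(q);
  });
}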

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp

Lines changed: 32 additions & 0 deletions

@@ -126,3 +126,35 @@ void warpgroup_reg_set() {
 }
 
 } // namespace cutlass::fmha::collective
+
+namespace constexpr_type_map {
+/*
+ * The following utility type_traits allow mapping constexpr variable to type at
+ * compile time.
+ * The default return type defined for each map would be returned if queried key
+ * does not exist in the map.
+ */
+
+template <auto keyVal, typename _valueT>
+struct cValTypePair {
+  static constexpr auto key = keyVal;
+  using valueT = _valueT;
+};
+
+template <typename Default, typename FirstMapping, typename ...OtherMapping>
+struct TypeMap {
+  template<auto QueryKey>
+  using query = std::conditional_t<
+    QueryKey == FirstMapping::key,
+    typename FirstMapping::valueT,
+    typename TypeMap<Default, OtherMapping...>::template query<QueryKey>
+  >;
+};
+
+template <typename Default, typename LastMapping>
+struct TypeMap<Default, LastMapping> {
+  template<auto QueryKey>
+  using query = std::conditional_t<QueryKey == LastMapping::key, typename LastMapping::valueT, Default>;
+};
+
+}
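
The map is queried with a constexpr key and yields the associated type, falling back to the Default parameter when no key matches. A minimal standalone sketch with throwaway key/type pairs, assuming the definitions above are in scope:

// Hedged sketch: compile-time key -> type lookup with the TypeMap utility above.
#include <type_traits>

using namespace constexpr_type_map;

// Map element-size-in-bytes to an arbitrary illustrative type.
using WidthMap = TypeMap<void,
    cValTypePair<1, char>,
    cValTypePair<2, short>>;

static_assert(std::is_same_v<WidthMap::query<1>, char>);   // key present
static_assert(std::is_same_v<WidthMap::query<2>, short>);  // key present
static_assert(std::is_same_v<WidthMap::query<4>, void>);   // missing key -> Default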

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp

Lines changed: 17 additions & 3 deletions

@@ -86,7 +86,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
-  static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{});
+  static constexpr int StageCountKV = StageCountQ * ((sizeof(Element) == 1) ? /*fp8*/ 12 : /*bf16/fp16*/ 5);
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
@@ -540,9 +540,23 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
+    using TMEM_LOAD_OPMAP = constexpr_type_map::TypeMap<void,
+        constexpr_type_map::cValTypePair<1, SM100_TMEM_LOAD_32dp32b8x>,
+        constexpr_type_map::cValTypePair<2, SM100_TMEM_LOAD_32dp32b16x>
+    >;
+    using TMEM_STORE_OPMAP = constexpr_type_map::TypeMap<void,
+        constexpr_type_map::cValTypePair<1, SM100_TMEM_STORE_32dp32b8x>,
+        constexpr_type_map::cValTypePair<2, SM100_TMEM_STORE_32dp32b16x>
+    >;
     // Each thread owns a single row
-    using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
-    using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_LOAD = conditional_t<size<1>(
+        TileShapeQK{}) < _128{},
+        TMEM_LOAD_OPMAP::query<sizeof(Element)>,
+        SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_STORE = conditional_t<size<1>(
+        TileShapeQK{}) < _128{},
+        TMEM_STORE_OPMAP::query<sizeof(Element)>,
+        SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
     using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
 
     int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
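
Two things change for 2-byte elements: the TMEM load/store intrinsics are picked through the new map so the per-row transfer width matches sizeof(Element), and the KV pipeline is shallower because each bf16/fp16 KV stage occupies twice the shared memory of an fp8 stage. A standalone restatement of the revised stage-count rule (the 12/5 budgets and the 256-column special case are taken from the diff above; everything else is illustrative):

// Hedged sketch: the revised stage-count rule from the diff, restated as
// constexpr helpers so the fp8 vs. bf16/fp16 pipeline depths are easy to check.
constexpr int stage_count_q(int tile_n) {
  return tile_n == 256 ? 1 : 2;              // mirrors StageCountQ above
}

constexpr int stage_count_kv(int tile_n, int elem_bytes) {
  // mirrors StageCountKV above: fewer, larger KV stages for 2-byte elements
  return stage_count_q(tile_n) * (elem_bytes == 1 ? /*fp8*/ 12 : /*bf16,fp16*/ 5);
}

static_assert(stage_count_kv(256, 1) == 12);  // fp8, 256-column tile
static_assert(stage_count_kv(256, 2) == 5);   // bf16/fp16, 256-column tile
static_assert(stage_count_kv(128, 2) == 10);  // bf16/fp16, two Q stages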

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,7 @@
 // common
 #include "cutlass/cutlass.h"
 #include "cutlass/device_kernel.h"
+#include "cutlass/arch/arch.h"
 
 #if !defined(__CUDACC_RTC__)
 #include "cutlass/cluster_launch.hpp"
@@ -57,6 +58,7 @@ template <class Kernel_>
 class FMHA {
  public:
   using Kernel = Kernel_;
+  static_assert(Kernel::SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity.");
 
   static int const kThreadCount = Kernel::MaxThreadsPerBlock;
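
The static_assert turns a would-be runtime launch failure into a compile-time error. The same guard works for any wrapper that knows its kernel's SharedStorageSize as a constant; a minimal standalone sketch, where kSmemBudgetBytes is an illustrative placeholder standing in for cutlass::arch::sm100_smem_capacity_bytes:

// Hedged sketch: reject over-budget shared-memory configurations at compile time.
// kSmemBudgetBytes is an illustrative placeholder, not the real SM100 constant.
#include <cstddef>

constexpr std::size_t kSmemBudgetBytes = 228 * 1024;

template <class Kernel>
struct LaunchWrapper {
  static_assert(Kernel::SharedStorageSize <= kSmemBudgetBytes,
                "SMEM usage exceeded capacity.");
};

struct SmallKernel { static constexpr std::size_t SharedStorageSize = 64 * 1024; };
LaunchWrapper<SmallKernel> ok{};  // compiles: within budget
// struct HugeKernel { static constexpr std::size_t SharedStorageSize = 512 * 1024; };
// LaunchWrapper<HugeKernel> bad{};  // fails to compile with the message above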

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py

Lines changed: 12 additions & 6 deletions

@@ -439,25 +439,31 @@ def _execute_cutlass_blackwell_attn_varlen(
                 seqlen_k,
                 batch_size,
                 is_mqa,
+                q_heads,
+                dtype,
             )
             for seqlen_k in [64, 128, 256, 1024]
             for batch_size in [1, 2]
             for is_mqa in [True]
-        ]
+            for q_heads in [8]
+            for dtype in [torch.float8_e4m3fn, torch.bfloat16]
+        ],
+        name_func=lambda func, num, p: func.__name__
+        + "_"
+        + num
+        + "_"
+        + "_".join(map(str, p)),
     )
     def test_decode(
         self,
         seqlen_k: int,
         batch_size: int,
        is_mqa: bool,
-        q_heads: int = 8,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        q_heads: int,
+        dtype: torch.dtype,
     ) -> None:
         seqlen_q = 1
         causal = True
-        assert (
-            dtype == torch.float8_e4m3fn
-        ), "Gen Kernel only supports float8_e4m3fn for now"
         self._execute_cutlass_blackwell_attn_dense(
             batch_size,
             seqlen_q,
