From 0887844f15928ef8facb7fe688ba0918106446b7 Mon Sep 17 00:00:00 2001
From: Gefei Zuo
Date: Wed, 24 Sep 2025 14:51:33 -0700
Subject: [PATCH] Support bf16 in blackwell cutlass decode attention kernel (#4916)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1940

1. Reduce pipeline stages to avoid exceeding the smem limit.
2. Add a static_assert so that smem capacity violations are raised at compile time rather than at runtime.
3. Select the TMEM intrinsics based on sizeof(Element).
4. Update the unit test to include bf16.
5. Label decode kernel test names with their corresponding test parameters.

Differential Revision: D82991495
---
 .../blackwell_gen_impl.cu                     |  9 ++++--
 .../collective/fmha_common.hpp                | 32 +++++++++++++++++++
 ...m100_fmha_gen_mainloop_warpspecialized.hpp | 20 ++++++++++--
 .../cutlass_blackwell_fmha/device/fmha.hpp    |  2 ++
 .../test/attention/blackwell_fmha_test.py     | 18 +++++++----
 5 files changed, 70 insertions(+), 11 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
index 71c4603eea..00de27b3d4 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
@@ -262,10 +262,15 @@ struct GenRunner {
 };
 
 // Dispatch macros for different element types
-// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types.
 #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \
   [&] {                                                 \
-    if (DTYPE == at::kFloat8_e4m3fn) {                  \
+    if (DTYPE == at::kHalf) {                           \
+      using ELEMENT_TYPE = cutlass::half_t;             \
+      return __VA_ARGS__();                             \
+    } else if (DTYPE == at::kBFloat16) {                \
+      using ELEMENT_TYPE = cutlass::bfloat16_t;         \
+      return __VA_ARGS__();                             \
+    } else if (DTYPE == at::kFloat8_e4m3fn) {           \
       using ELEMENT_TYPE = cutlass::float_e4m3_t;       \
       return __VA_ARGS__();                             \
     } else {                                            \
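For orientation, the macro above follows the usual ATen-style lambda dispatch pattern: the runtime dtype selects a compile-time cutlass element type, and the trailing lambda is compiled once per supported type with that alias in scope. A minimal sketch of a call site follows; dispatch_example and run_gen_kernel are hypothetical names used only for illustration and are not part of this patch.

#include <ATen/ATen.h>

// Hypothetical templated runner, standing in for the real GenRunner entry point.
template <typename Element>
at::Tensor run_gen_kernel(const at::Tensor& q);

// Hypothetical call site: q's runtime dtype picks the cutlass element type,
// then the lambda body is instantiated with `Element` bound to that type.
at::Tensor dispatch_example(const at::Tensor& q) {
  return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] {
    // Element is cutlass::half_t, cutlass::bfloat16_t, or cutlass::float_e4m3_t here.
    return run_gen_kernel<Element>(q);
  });
}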
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp
index 2d3e2b166d..708f560368 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp
@@ -126,3 +126,35 @@ void warpgroup_reg_set() {
 }
 
 } // namespace cutlass::fmha::collective
+
+namespace constexpr_type_map {
+/*
+ * The following utility type_traits allow mapping a constexpr variable to a
+ * type at compile time.
+ * The default return type defined for each map is returned if the queried key
+ * does not exist in the map.
+ */
+
+template <auto keyVal, typename _valueT>
+struct cValTypePair {
+  static constexpr auto key = keyVal;
+  using valueT = _valueT;
+};
+
+template <typename defaultT, typename FirstMapping, typename... RestMappings>
+struct TypeMap {
+  template <auto QueryKey>
+  using query = std::conditional_t<
+      QueryKey == FirstMapping::key,
+      typename FirstMapping::valueT,
+      typename TypeMap<defaultT, RestMappings...>::template query<QueryKey>
+  >;
+};
+
+template <typename defaultT, typename LastMapping>
+struct TypeMap<defaultT, LastMapping> {
+  template <auto QueryKey>
+  using query = std::conditional_t<QueryKey == LastMapping::key, typename LastMapping::valueT, defaultT>;
+};
+
+}
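To make the compile-time semantics of this map concrete, here is a hedged usage sketch; TypeA, TypeB, Fallback, and MyMap are made-up names, and the first template argument is taken to be the default type returned on a miss, per the comment above.

#include <type_traits>

struct TypeA {};
struct TypeB {};
struct Fallback {};

// Map constexpr keys 1 and 2 to types, with Fallback returned for unknown keys.
using MyMap = constexpr_type_map::TypeMap<Fallback,
    constexpr_type_map::cValTypePair<1, TypeA>,
    constexpr_type_map::cValTypePair<2, TypeB>>;

static_assert(std::is_same_v<MyMap::query<1>, TypeA>);
static_assert(std::is_same_v<MyMap::query<2>, TypeB>);
static_assert(std::is_same_v<MyMap::query<3>, Fallback>);  // miss falls back to the default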
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
index f0442e06a8..fe477b930b 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
@@ -86,7 +86,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
-  static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{});
+  static constexpr int StageCountKV = StageCountQ * ((sizeof(Element) == 1) ? /*fp8*/ 11 : /*bf16/fp16*/ 5);
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
@@ -540,9 +540,23 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
+    using TMEM_LOAD_OPMAP = constexpr_type_map::TypeMap<void,
+        constexpr_type_map::cValTypePair<1, SM100_TMEM_LOAD_32dp32b8x>,
+        constexpr_type_map::cValTypePair<2, SM100_TMEM_LOAD_32dp32b16x>
+    >;
+    using TMEM_STORE_OPMAP = constexpr_type_map::TypeMap<void,
+        constexpr_type_map::cValTypePair<1, SM100_TMEM_STORE_32dp32b8x>,
+        constexpr_type_map::cValTypePair<2, SM100_TMEM_STORE_32dp32b16x>
+    >;
     // Each thread owns a single row
-    using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
-    using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_LOAD = conditional_t<size<1>(
+        TileShapeQK{}) < _128{},
+        TMEM_LOAD_OPMAP::query<sizeof(Element)>,
+        SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_STORE = conditional_t<size<1>(
+        TileShapeQK{}) < _128{},
+        TMEM_STORE_OPMAP::query<sizeof(Element)>,
+        SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
     using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
 
     int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp
index d0f4331cea..ccff059f53 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp
@@ -39,6 +39,7 @@
 // common
 #include "cutlass/cutlass.h"
 #include "cutlass/device_kernel.h"
+#include "cutlass/arch/arch.h"
 
 #if !defined(__CUDACC_RTC__)
 #include "cutlass/cluster_launch.hpp"
@@ -57,6 +58,7 @@ template <class Kernel_>
 class FMHA {
  public:
   using Kernel = Kernel_;
+  static_assert(Kernel::SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity.");
 
   static int const kThreadCount =
       Kernel::MaxThreadsPerBlock;
diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
index 399321e521..a06519e590 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
@@ -440,25 +440,31 @@ def _execute_cutlass_blackwell_attn_varlen(
                 seqlen_k,
                 batch_size,
                 is_mqa,
+                q_heads,
+                dtype,
             )
             for seqlen_k in [64, 128, 256, 1024]
             for batch_size in [1, 2]
             for is_mqa in [True]
-        ]
+            for q_heads in [8]
+            for dtype in [torch.float8_e4m3fn, torch.bfloat16]
+        ],
+        name_func=lambda func, num, p: func.__name__
+        + "_"
+        + num
+        + "_"
+        + "_".join(map(str, p)),
     )
     def test_decode(
        self,
         seqlen_k: int,
         batch_size: int,
         is_mqa: bool,
-        q_heads: int = 8,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        q_heads: int,
+        dtype: torch.dtype,
     ) -> None:
         seqlen_q = 1
         causal = True
-        assert (
-            dtype == torch.float8_e4m3fn
-        ), "Gen Kernel only supports float8_e4m3fn for now"
         self._execute_cutlass_blackwell_attn_dense(
             batch_size,
             seqlen_q,
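Stepping back to the mainloop change above: the new StageCountKV formula scales the KV pipeline depth inversely with sizeof(Element), so the KV shared-memory footprint stays roughly constant across fp8 and bf16/fp16 and the kernel stays under the capacity now guarded by the static_assert in device/fmha.hpp. A standalone sketch of that arithmetic, using an illustrative helper that is not part of the patch:

#include <cstddef>

// Mirror the StageCountKV factor from the mainloop: 11 KV stages for 1-byte
// elements (fp8), 5 KV stages for 2-byte elements (bf16/fp16).
constexpr int kv_stage_factor(std::size_t elem_bytes) {
  return elem_bytes == 1 ? /*fp8*/ 11 : /*bf16/fp16*/ 5;
}

// Bytes of KV buffer per element slot, summed over all stages, stay comparable,
// so the wider dtype does not blow past the shared-memory budget.
static_assert(kv_stage_factor(1) * 1 == 11, "fp8: 11 stages x 1 byte");
static_assert(kv_stage_factor(2) * 2 == 10, "bf16/fp16: 5 stages x 2 bytes");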