From 5911680ba0ca5594c7d3d709572038dae4a082bf Mon Sep 17 00:00:00 2001 From: Gefei Zuo Date: Mon, 22 Sep 2025 12:54:01 -0700 Subject: [PATCH 1/3] Disable SWA related decode kernel tests since they are not supported Summary: D80992628 introduced SWA FWD kernel changes which did not support decode kernels (i.e., supporting sm100_fmha_fwd but not sm100_fmha_gen). Similarly, softmax_scale introduced in D82788784 did not support decode kernels either. In blackwell_fmha_test, the these parameters are dropped during decode kernel selection (https://www.internalfb.com/code/fbsource/[cd7066706035]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py?lines=182) To avoid confusion, do not test test_decode with ignored parameters. Differential Revision: D82991496 --- .../gen_ai/test/attention/blackwell_fmha_test.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 9ede05918e..758332b315 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -439,14 +439,10 @@ def _execute_cutlass_blackwell_attn_varlen( seqlen_k, batch_size, is_mqa, - window_size, - sm_scale, ) for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] for is_mqa in [True] - for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)] - for sm_scale in [None, 1.0 / 128] ] ) def test_decode( @@ -454,8 +450,6 @@ def test_decode( seqlen_k: int, batch_size: int, is_mqa: bool, - window_size: tuple[int, int], - sm_scale: Optional[float], q_heads: int = 8, dtype: torch.dtype = torch.float8_e4m3fn, ) -> None: @@ -473,10 +467,12 @@ def test_decode( head_dim=128, dtype=dtype, causal=causal, - window_size=window_size, + # Decode kernel does not support sliding window attention yet + window_size=(-1, -1), fwd_only=True, deterministic=False, - sm_scale=sm_scale, + # Decode kernel does not support sm_scale + sm_scale=None, ) @skip_cuda_lt_sm100 From bfaa65188f8752f6853a0b7d6f2b93d4d734ad48 Mon Sep 17 00:00:00 2001 From: Gefei Zuo Date: Mon, 22 Sep 2025 12:55:14 -0700 Subject: [PATCH 2/3] Support bf16 in blackwell cutlass decode attention kernel Summary: 1. Reduce pipeline stages to avoid exceeding smem limit 2. Add static_assert to make sure smem capacity violation is raised during compilation rather than runtime 3. Select the TMEM intrinsics based on sizeof(Element). 4. Update unittest to include bf16 5. Also label decode kernel test name with their corresponding test parameters. Differential Revision: D82991495 --- .../blackwell_gen_impl.cu | 9 ++++-- .../collective/fmha_common.hpp | 32 +++++++++++++++++++ ...m100_fmha_gen_mainloop_warpspecialized.hpp | 20 ++++++++++-- .../cutlass_blackwell_fmha/device/fmha.hpp | 2 ++ .../test/attention/blackwell_fmha_test.py | 18 +++++++---- 5 files changed, 70 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 71c4603eea..00de27b3d4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -262,10 +262,15 @@ struct GenRunner { }; // Dispatch macros for different element types -// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types. #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \ [&] { \ - if (DTYPE == at::kFloat8_e4m3fn) { \ + if (DTYPE == at::kHalf) { \ + using ELEMENT_TYPE = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else if (DTYPE == at::kBFloat16) { \ + using ELEMENT_TYPE = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } else if (DTYPE == at::kFloat8_e4m3fn) { \ using ELEMENT_TYPE = cutlass::float_e4m3_t; \ return __VA_ARGS__(); \ } else { \ diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp index 2d3e2b166d..708f560368 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp @@ -126,3 +126,35 @@ void warpgroup_reg_set() { } } // namespace cutlass::fmha::collective + +namespace constexpr_type_map { +/* + * The following utility type_traits allow mapping constexpr variable to type at + * compile time. + * The default return type defined for each map would be returned if queried key + * does not exist in the map. + */ + +template +struct cValTypePair { + static constexpr auto key = keyVal; + using valueT = _valueT; +}; + +template +struct TypeMap { + template + using query = std::conditional_t< + QueryKey == FirstMapping::key, + typename FirstMapping::valueT, + typename TypeMap::template query + >; +}; + +template +struct TypeMap { + template + using query = std::conditional_t; +}; + +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index f0442e06a8..acc2513bc9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -86,7 +86,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { using Mask = Mask_; static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2; - static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); + static constexpr int StageCountKV = StageCountQ * ((sizeof(Element) == 1) ? /*fp8*/ 12 : /*bf16/fp16*/5); using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; @@ -540,9 +540,23 @@ struct Sm100FmhaGenMainloopWarpspecialized { tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + using TMEM_LOAD_OPMAP = constexpr_type_map::TypeMap, + constexpr_type_map::cValTypePair<2, SM100_TMEM_LOAD_32dp32b16x> + >; + using TMEM_STORE_OPMAP = constexpr_type_map::TypeMap, + constexpr_type_map::cValTypePair<2, SM100_TMEM_STORE_32dp32b16x> + >; // Each thread owns a single row - using TMEM_LOAD = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_LOAD = conditional_t( + TileShapeQK{}) < _128{}, + TMEM_LOAD_OPMAP::query, + SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE = conditional_t( + TileShapeQK{}) < _128{}, + TMEM_STORE_OPMAP::query, + SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp index d0f4331cea..ccff059f53 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp @@ -39,6 +39,7 @@ // common #include "cutlass/cutlass.h" #include "cutlass/device_kernel.h" +#include "cutlass/arch/arch.h" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" @@ -57,6 +58,7 @@ template class FMHA { public: using Kernel = Kernel_; + static_assert(Kernel::SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); static int const kThreadCount = Kernel::MaxThreadsPerBlock; diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 758332b315..45d5415aa2 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -439,25 +439,31 @@ def _execute_cutlass_blackwell_attn_varlen( seqlen_k, batch_size, is_mqa, + q_heads, + dtype, ) for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] for is_mqa in [True] - ] + for q_heads in [8] + for dtype in [torch.float8_e4m3fn, torch.bfloat16] + ], + name_func=lambda func, num, p: func.__name__ + + "_" + + num + + "_" + + "_".join(map(str, p)), ) def test_decode( self, seqlen_k: int, batch_size: int, is_mqa: bool, - q_heads: int = 8, - dtype: torch.dtype = torch.float8_e4m3fn, + q_heads: int, + dtype: torch.dtype, ) -> None: seqlen_q = 1 causal = True - assert ( - dtype == torch.float8_e4m3fn - ), "Gen Kernel only supports float8_e4m3fn for now" self._execute_cutlass_blackwell_attn_dense( batch_size, seqlen_q, From 5589cbb90eabc66a84ddfea664ef3c2c06923ef6 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Mon, 22 Sep 2025 19:43:40 -0700 Subject: [PATCH 3/3] Add cutlass decode kernel to TritonBench (#4853) Summary: X-link: https://github.com/meta-pytorch/tritonbench/pull/376 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4853 X-link: https://github.com/facebookresearch/FBGEMM/pull/1875 Add cutlass blackwell FMHA decode kernel implementation to TritonBench benchmarking suite . Reviewed By: sryap Differential Revision: D80041532 --- .../attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 00de27b3d4..08962f207c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -304,7 +304,7 @@ at::Tensor dispatch_fmha_gen_fwd( return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] { return DISPATCH_KERNEL_TYPE(static_cast(kernel_type), KType, [&] { - GenRunner, Shape<_1, _1, _1>> + GenRunner, Shape<_1, _1, _1>> runner; return runner.fmha_fwd(q, k, v, seqlen_kv, batch_idx); });