diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 71c4603eea..00de27b3d4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -262,10 +262,15 @@ struct GenRunner { }; // Dispatch macros for different element types -// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types. #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \ [&] { \ - if (DTYPE == at::kFloat8_e4m3fn) { \ + if (DTYPE == at::kHalf) { \ + using ELEMENT_TYPE = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else if (DTYPE == at::kBFloat16) { \ + using ELEMENT_TYPE = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } else if (DTYPE == at::kFloat8_e4m3fn) { \ using ELEMENT_TYPE = cutlass::float_e4m3_t; \ return __VA_ARGS__(); \ } else { \ diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp index 2d3e2b166d..708f560368 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp @@ -126,3 +126,35 @@ void warpgroup_reg_set() { } } // namespace cutlass::fmha::collective + +namespace constexpr_type_map { +/* + * The following utility type_traits allow mapping constexpr variable to type at + * compile time. + * The default return type defined for each map would be returned if queried key + * does not exist in the map. 
+ */
+
+template <auto keyVal, typename _valueT>
+struct cValTypePair {
+  static constexpr auto key = keyVal;
+  using valueT = _valueT;
+};
+
+template <typename DefaultT, typename FirstMapping, typename... RestMappings>
+struct TypeMap {
+  template <auto QueryKey>
+  using query = std::conditional_t<
+      QueryKey == FirstMapping::key,
+      typename FirstMapping::valueT,
+      typename TypeMap<DefaultT, RestMappings...>::template query<QueryKey>
+  >;
+};
+
+template <typename DefaultT, typename LastMapping>
+struct TypeMap<DefaultT, LastMapping> {
+  template <auto QueryKey>
+  using query = std::conditional_t<QueryKey == LastMapping::key, typename LastMapping::valueT, DefaultT>;
+};
+
+}
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
index f0442e06a8..fe477b930b 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
@@ -86,7 +86,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
   using Mask = Mask_;
 
   static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
-  static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{});
+  static constexpr int StageCountKV = StageCountQ * ((sizeof(Element) == 1) ? /*fp8*/ 11 : /*bf16/fp16*/ 5);
 
   using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
   using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
@@ -540,9 +540,23 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
+    using TMEM_LOAD_OPMAP = constexpr_type_map::TypeMap<void,
+        constexpr_type_map::cValTypePair<1, SM100_TMEM_LOAD_32dp32b8x>,
+        constexpr_type_map::cValTypePair<2, SM100_TMEM_LOAD_32dp32b16x>
+    >;
+    using TMEM_STORE_OPMAP = constexpr_type_map::TypeMap<void,
+        constexpr_type_map::cValTypePair<1, SM100_TMEM_STORE_32dp32b8x>,
+        constexpr_type_map::cValTypePair<2, SM100_TMEM_STORE_32dp32b16x>
+    >;
     // Each thread owns a single row
-    using TMEM_LOAD = conditional_t<get<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
-    using TMEM_STORE = conditional_t<get<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_LOAD = conditional_t<get<1>(
+        TileShapeQK{}) < _128{},
+        TMEM_LOAD_OPMAP::query<sizeof(Element)>,
+        SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
+    using TMEM_STORE = conditional_t<get<1>(
+        TileShapeQK{}) < _128{},
+        TMEM_STORE_OPMAP::query<sizeof(Element)>,
+        SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
     using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
 
     int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp
index d0f4331cea..ccff059f53 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp
@@ -39,6 +39,7 @@
 // common
 #include "cutlass/cutlass.h"
 #include "cutlass/device_kernel.h"
+#include "cutlass/arch/arch.h"
 
 #if !defined(__CUDACC_RTC__)
 #include "cutlass/cluster_launch.hpp"
@@ -57,6 +58,7 @@
 template <class Kernel_>
 class FMHA {
  public:
   using Kernel = Kernel_;
+  static_assert(Kernel::SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded 
capacity."); static int const kThreadCount = Kernel::MaxThreadsPerBlock; diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 399321e521..a06519e590 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -440,25 +440,31 @@ def _execute_cutlass_blackwell_attn_varlen( seqlen_k, batch_size, is_mqa, + q_heads, + dtype, ) for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] for is_mqa in [True] - ] + for q_heads in [8] + for dtype in [torch.float8_e4m3fn, torch.bfloat16] + ], + name_func=lambda func, num, p: func.__name__ + + "_" + + num + + "_" + + "_".join(map(str, p)), ) def test_decode( self, seqlen_k: int, batch_size: int, is_mqa: bool, - q_heads: int = 8, - dtype: torch.dtype = torch.float8_e4m3fn, + q_heads: int, + dtype: torch.dtype, ) -> None: seqlen_q = 1 causal = True - assert ( - dtype == torch.float8_e4m3fn - ), "Gen Kernel only supports float8_e4m3fn for now" self._execute_cutlass_blackwell_attn_dense( batch_size, seqlen_q,