diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
index 00de27b3d4..71c4603eea 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
@@ -262,15 +262,10 @@ struct GenRunner {
 };
 
 // Dispatch macros for different element types
+// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types.
 #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \
   [&] {                                                 \
-    if (DTYPE == at::kHalf) {                           \
-      using ELEMENT_TYPE = cutlass::half_t;             \
-      return __VA_ARGS__();                             \
-    } else if (DTYPE == at::kBFloat16) {                \
-      using ELEMENT_TYPE = cutlass::bfloat16_t;         \
-      return __VA_ARGS__();                             \
-    } else if (DTYPE == at::kFloat8_e4m3fn) {           \
+    if (DTYPE == at::kFloat8_e4m3fn) {                  \
       using ELEMENT_TYPE = cutlass::float_e4m3_t;       \
       return __VA_ARGS__();                             \
     } else {                                            \
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
index d841a8cf90..f0442e06a8 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp
@@ -535,14 +535,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
     tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
     Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
 
-    auto tilePlikeFP32 = get<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
+    auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
     Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
     tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
     Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
 
     // Each thread owns a single row
-    using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x;  // 4x32 threads with 128 cols of 32b elem
-    using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x;  // 4x32 threads with 128 cols of 8b elem
+    using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>;  // 4x32 threads with 128 cols of 8b elem
+    using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>;  // 4x32 threads with 128 cols of 8b elem
     using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x;  // 4x32 threads with 2 cols of 32b elem
 
     int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp
index 5e7f480cfb..a1c6d627be 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp
@@ -366,7 +366,7 @@ struct Sm100FmhaGenKernelWarpspecialized {
       pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer;
     }
     pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp;
-    pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp;
+    pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp);
     typename CollectiveMainloop::PipelineE pipeline_corr_epi(
         shared_storage.pipelines.corr_epi,
         pipeline_corr_epi_params,
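
Note on the mainloop hunk: the TMEM copy atom is now chosen from the K extent of TileShapeQK (tiles narrower than 128 use the 8-column atom, wider tiles keep the 32-column one), and tilePlikeFP32 re-expresses the 8-bit P tile in 32-bit TMEM columns. Below is a minimal standalone sketch of that selection logic, assuming std::conditional_t in place of the CuTe alias and hypothetical stand-in types (MainloopSketch, TmemLoad8x, TmemLoad32x, and the name strings are not fbgemm/CUTLASS identifiers).

    // Sketch only: plain C++ mirror of the compile-time atom selection above.
    #include <cstdio>
    #include <type_traits>

    struct TmemLoad8x  { static constexpr const char* name = "32dp32b8x";  };
    struct TmemLoad32x { static constexpr const char* name = "32dp32b32x"; };

    // Element is the P (attention probability) storage type; 1 byte for fp8.
    template <int KTile, class Element>
    struct MainloopSketch {
      // tilePlikeFP32: KTile 8-bit P values expressed as 32-bit TMEM columns,
      // i.e. KTile / sizeof(float) * sizeof(Element).
      static constexpr int kTilePlikeFP32 =
          KTile / int(sizeof(float)) * int(sizeof(Element));
      // K tiles narrower than 128 select the 8-column atom, as in the diff.
      using TMEM_LOAD = std::conditional_t<(KTile < 128), TmemLoad8x, TmemLoad32x>;
    };

    int main() {
      using Narrow = MainloopSketch<64, unsigned char>;   // hypothetical K = 64 tile
      using Wide   = MainloopSketch<128, unsigned char>;  // hypothetical K = 128 tile
      std::printf("K=64 : P cols as fp32 = %d, atom = %s\n",
                  Narrow::kTilePlikeFP32, Narrow::TMEM_LOAD::name);  // 16, 32dp32b8x
      std::printf("K=128: P cols as fp32 = %d, atom = %s\n",
                  Wide::kTilePlikeFP32, Wide::TMEM_LOAD::name);      // 32, 32dp32b32x
      return 0;
    }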
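
Note on the kernel hunk: clamping consumer_arv_count with cute::max(1, ...) presumably guards configurations where NumWarpsEpilogue is 0, since the corr-to-epi pipeline barrier would otherwise be initialized with an arrival count of zero, which is not a valid mbarrier count. A small compile-time sketch of the guard's effect, using a hypothetical consumer_arv_count helper in plain C++ rather than the CUTLASS pipeline API:

    // Sketch only: the clamp keeps the arrival count positive.
    #include <algorithm>

    constexpr int kNumThreadsPerWarp = 32;

    constexpr int consumer_arv_count(int num_warps_epilogue) {
      // Clamp to 1: a barrier arrival count of zero is not a valid configuration.
      return std::max(1, num_warps_epilogue * kNumThreadsPerWarp);
    }

    static_assert(consumer_arv_count(0) == 1,  "epilogue-less config still arrives once");
    static_assert(consumer_arv_count(1) == 32, "unchanged when epilogue warps exist");

    int main() { return 0; }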