@@ -262,10 +262,15 @@ struct GenRunner {
};

// Dispatch macros for different element types
// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types.
#define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \
[&] { \
if (DTYPE == at::kFloat8_e4m3fn) { \
if (DTYPE == at::kHalf) { \
using ELEMENT_TYPE = cutlass::half_t; \
return __VA_ARGS__(); \
} else if (DTYPE == at::kBFloat16) { \
using ELEMENT_TYPE = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} else if (DTYPE == at::kFloat8_e4m3fn) { \
using ELEMENT_TYPE = cutlass::float_e4m3_t; \
return __VA_ARGS__(); \
} else { \
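The fallback else branch of the macro is collapsed in this view. For orientation, a minimal hypothetical call site written in the usual PyTorch-style dispatch pattern; dispatch_example and launch_gen_kernel are illustrative names, not from this diff:

// Sketch of a call site: dispatch on the runtime dtype, then run a kernel
// templated on the matching CUTLASS element type.
void dispatch_example(at::ScalarType dtype) {
  DISPATCH_ELEMENT_TYPE(dtype, Element, [&] {
    // Element is cutlass::half_t, cutlass::bfloat16_t, or cutlass::float_e4m3_t here.
    launch_gen_kernel<Element>();  // hypothetical launcher
  });
}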
@@ -126,3 +126,35 @@ void warpgroup_reg_set() {
}

} // namespace cutlass::fmha::collective

namespace constexpr_type_map {
/*
 * The following utility type traits map a constexpr key to a type at
 * compile time.
 * The map's default type is returned if the queried key does not exist
 * in the map.
 */

template <auto keyVal, typename _valueT>
struct cValTypePair {
static constexpr auto key = keyVal;
using valueT = _valueT;
};

template <typename Default, typename FirstMapping, typename ...OtherMapping>
struct TypeMap {
template<auto QueryKey>
using query = std::conditional_t<
QueryKey == FirstMapping::key,
typename FirstMapping::valueT,
typename TypeMap<Default, OtherMapping...>::template query<QueryKey>
>;
};

template <typename Default, typename LastMapping>
struct TypeMap<Default, LastMapping> {
template<auto QueryKey>
using query = std::conditional_t<QueryKey == LastMapping::key, typename LastMapping::valueT, Default>;
};

} // namespace constexpr_type_map
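A minimal usage sketch of the map; the keys and value types below are illustrative only, and <type_traits> is assumed to be available since the map itself relies on std::conditional_t:

// Map 1 -> float and 2 -> double; any other key falls back to the Default (void).
using WidthMap = constexpr_type_map::TypeMap<void,
    constexpr_type_map::cValTypePair<1, float>,
    constexpr_type_map::cValTypePair<2, double>
>;
static_assert(std::is_same_v<WidthMap::query<1>, float>);
static_assert(std::is_same_v<WidthMap::query<2>, double>);
static_assert(std::is_same_v<WidthMap::query<4>, void>);  // missing key -> Default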
@@ -86,7 +86,7 @@ struct Sm100FmhaGenMainloopWarpspecialized {
using Mask = Mask_;

static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2;
static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{});
static constexpr int StageCountKV = StageCountQ * ((sizeof(Element) == 1) ? /*fp8*/ 11 : /*bf16/fp16*/ 5);

using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
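A hedged sketch of the stage-count arithmetic above; stage_count_kv is a stand-in helper (not in the diff) taking get<1>(TileShape{}) and sizeof(Element):

constexpr int stage_count_kv(int q_tile, int elem_bytes) {
  int stage_count_q = (q_tile == 256) ? 1 : 2;
  return stage_count_q * ((elem_bytes == 1) ? /*fp8*/ 11 : /*bf16/fp16*/ 5);
}
// A 256-wide Q tile gives 11 KV stages for fp8 and 5 for the 16-bit types.
static_assert(stage_count_kv(256, 1) == 11, "");
static_assert(stage_count_kv(256, 2) == 5, "");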
@@ -540,9 +540,23 @@ struct Sm100FmhaGenMainloopWarpspecialized {
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));

using TMEM_LOAD_OPMAP = constexpr_type_map::TypeMap<void,
constexpr_type_map::cValTypePair<1, SM100_TMEM_LOAD_32dp32b8x>,
constexpr_type_map::cValTypePair<2, SM100_TMEM_LOAD_32dp32b16x>
>;
using TMEM_STORE_OPMAP = constexpr_type_map::TypeMap<void,
constexpr_type_map::cValTypePair<1, SM100_TMEM_STORE_32dp32b8x>,
constexpr_type_map::cValTypePair<2, SM100_TMEM_STORE_32dp32b16x>
>;
// Each thread owns a single row
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_LOAD = conditional_t<
    size<1>(TileShapeQK{}) < _128{},
    TMEM_LOAD_OPMAP::query<sizeof(Element)>,
    SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE = conditional_t<
    size<1>(TileShapeQK{}) < _128{},
    TMEM_STORE_OPMAP::query<sizeof(Element)>,
    SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem

int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
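For reference, how the op maps above resolve for the two supported element widths (a sketch only; it assumes the query mechanism behaves as defined in constexpr_type_map and would sit in the same scope as the aliases):

// 1-byte elements (fp8) keep the 8-column TMEM ops; 2-byte elements
// (fp16/bf16) select the 16-column variants; any other width yields void.
static_assert(std::is_same_v<TMEM_LOAD_OPMAP::query<1>, SM100_TMEM_LOAD_32dp32b8x>, "");
static_assert(std::is_same_v<TMEM_LOAD_OPMAP::query<2>, SM100_TMEM_LOAD_32dp32b16x>, "");
static_assert(std::is_same_v<TMEM_STORE_OPMAP::query<2>, SM100_TMEM_STORE_32dp32b16x>, "");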
@@ -39,6 +39,7 @@
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/arch/arch.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
@@ -57,6 +58,7 @@ template <class Kernel_>
class FMHA {
public:
using Kernel = Kernel_;
static_assert(Kernel::SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity.");

static int const kThreadCount = Kernel::MaxThreadsPerBlock;

@@ -440,25 +440,31 @@ def _execute_cutlass_blackwell_attn_varlen(
seqlen_k,
batch_size,
is_mqa,
q_heads,
dtype,
)
for seqlen_k in [64, 128, 256, 1024]
for batch_size in [1, 2]
for is_mqa in [True]
]
for q_heads in [8]
for dtype in [torch.float8_e4m3fn, torch.bfloat16]
],
name_func=lambda func, num, p: func.__name__
+ "_"
+ num
+ "_"
+ "_".join(map(str, p)),
)
def test_decode(
self,
seqlen_k: int,
batch_size: int,
is_mqa: bool,
q_heads: int = 8,
dtype: torch.dtype = torch.float8_e4m3fn,
q_heads: int,
dtype: torch.dtype,
) -> None:
seqlen_q = 1
causal = True
assert (
dtype == torch.float8_e4m3fn
), "Gen Kernel only supports float8_e4m3fn for now"
self._execute_cutlass_blackwell_attn_dense(
batch_size,
seqlen_q,