Skip to content

Commit 7e98d8b

Browse files
authored
perf: add 1x4x1 cluster shape for fp8 bmm M<16 cases (#1473)
1 parent ebcf044 commit 7e98d8b

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

csrc/fp8_gemm_cutlass.jinja

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace flashinfer {
2020
namespace gemm {
2121
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
2222
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
23+
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
2324
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
2425
INSTANCE_FP8_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
2526
} // namespace gemm

include/flashinfer/gemm/fp8_gemm_cutlass_template.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const
8383
_2SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
8484
workspacePtr, workspaceBytes, stream);
8585
break;
86+
case ClusterShape::ClusterShape_1x4x1:
87+
return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _4, _1>,
88+
_1SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
89+
workspacePtr, workspaceBytes, stream);
90+
break;
8691
default:
8792
throw std::runtime_error("invalid config for fp8 gemm");
8893
break;
@@ -205,9 +210,8 @@ std::vector<CutlassGemmConfig> CutlassFp8GemmRunner<T>::getConfigs() const {
205210
};
206211

207212
std::vector<ClusterShape> clusterShapes = {
208-
ClusterShape::ClusterShape_1x1x1,
209-
ClusterShape::ClusterShape_1x2x1,
210-
ClusterShape::ClusterShape_2x1x1,
213+
ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1,
214+
ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_2x1x1,
211215
ClusterShape::ClusterShape_2x2x1,
212216
};
213217
for (auto const& tile_config : tilesSm100) {

0 commit comments

Comments
 (0)