Commit 4a0b190

djmmoss, jiahanc, and mgoin authored and committed
[feat]: add SM100 support for cutlass FP8 groupGEMM (vllm-project#20447)
Signed-off-by: Duncan Moss <[email protected]>
Signed-off-by: jiahanc <[email protected]>
Co-authored-by: jiahanc <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
1 parent 3c47ab0 commit 4a0b190

8 files changed: +255 −32 lines changed


CMakeLists.txt

Lines changed: 21 additions & 1 deletion
@@ -577,7 +577,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -595,6 +595,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
+    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
+                     "if you intend on running FP8 quantized MoE models on Blackwell.")
+    else()
+      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures.")
+    endif()
+  endif()
+
   # moe_data.cu is used by all CUTLASS MoE kernels.
   cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)

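The -DENABLE_CUTLASS_MOE_SM100=1 define added above is what lets the C++ entry points compile out cleanly when the SM100 kernels were not built. A minimal sketch of the guard pattern, not part of the diff (the real guarded declaration appears in scaled_mm_entry.cu further down; the signature here is abbreviated for illustration):

// Sketch only: SM100 grouped-GEMM symbols are declared/dispatched only when the
// kernels were compiled (CUDA >= 12.8 and a 10.0a arch in CUDA_ARCHS).
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(/* tensor, stride, and flag arguments elided */);
#endif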
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh

Lines changed: 9 additions & 4 deletions
@@ -18,7 +18,6 @@ using ProblemShape =
     cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

 using ElementAccumulator = float;
-using ArchTag = cutlass::arch::Sm90;
 using OperatorClass = cutlass::arch::OpClassTensorOp;

 using LayoutA = cutlass::layout::RowMajor;
@@ -33,7 +32,7 @@ using LayoutD_Transpose =
 using LayoutC = LayoutD;
 using LayoutC_Transpose = LayoutD_Transpose;

-template <typename ElementAB_, typename ElementC_,
+template <typename ElementAB_, typename ElementC_, typename ArchTag_,
           template <typename, typename, typename> typename Epilogue_,
           typename TileShape, typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule, bool swap_ab_ = false>
@@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm {
   using ElementC = void;
   using ElementD = ElementC_;
   using ElementAccumulator = float;
+  using ArchTag = ArchTag_;

   using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;

@@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm {
       LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
       Stages, KernelSchedule>::CollectiveOp>;

-  using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
+  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;

   struct GemmKernel : public KernelType {};
@@ -156,9 +156,14 @@ void cutlass_group_gemm_caller(
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<StrideC*>(c_strides.data_ptr())};

+  int device_id = a_tensors.device().index();
+  static const cutlass::KernelHardwareInfo hw_info{
+      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+                     device_id)};
+
   typename GemmKernel::Arguments args{
       cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
-      epilogue_args};
+      epilogue_args, hw_info};

   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
   GemmOp gemm_op;
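A minimal standalone sketch of the hardware-info pattern introduced in the last hunk, not part of the diff: the SM count is queried once for the active device and passed into the grouped-GEMM arguments so the persistent tile scheduler can size itself. The helper name make_hw_info is hypothetical; it assumes the CUTLASS 3.x header layout.

// Illustrative helper (assumption: cutlass/kernel_hardware_info.h from CUTLASS 3.x).
#include "cutlass/kernel_hardware_info.h"

cutlass::KernelHardwareInfo make_hw_info(int device_id) {
  // Query the multiprocessor count for this device; the real caller wraps this
  // in a function-local `static const` so the query runs only once.
  int const sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(device_id);
  return cutlass::KernelHardwareInfo{device_id, sm_count};
}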
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+#include <cudaTypedefs.h>
+
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
+#include "cutlass/cutlass.h"
+#include "grouped_mm_c3x.cuh"
+
+using namespace cute;
+
+namespace {
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_default {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
+  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm100;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_M64 {
+  // M in [1, 64]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
+  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm100;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule,
+                            true>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm100_fp8_config_N8192 {
+  // N in [8192, inf)
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
+  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
+  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm100;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType>
+void run_cutlass_moe_mm_sm100(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
+  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
+  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
+  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
+
+  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
+              "A tensors must be of type float8_e4m3fn.");
+  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
+              "B tensors must be of type float8_e4m3fn.");
+
+  using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+  using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+
+  uint32_t const m = a_tensors.size(0);
+  uint32_t const n = out_tensors.size(1);
+
+  if (m <= 64) {
+    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else if (n >= 8192) {
+    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else {
+    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  }
+}
+}  // namespace
+
+void dispatch_moe_mm_sm100(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
+  if (out_tensors.dtype() == torch::kBFloat16) {
+    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else {
+    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  }
+}
+
+void cutlass_moe_mm_sm100(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
+  dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                        expert_offsets, problem_sizes, a_strides, b_strides,
+                        c_strides, per_act_token, per_out_ch);
+}

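The tile-configuration heuristic in run_cutlass_moe_mm_sm100 above can be restated compactly. This is an illustrative sketch, not part of the diff; the enum and function names are hypothetical.

// Hypothetical summary of the SM100 config selection performed above.
#include <cstdint>

enum class Sm100GroupGemmConfig { kM64, kN8192, kDefault };

inline Sm100GroupGemmConfig select_sm100_config(uint32_t m, uint32_t n) {
  if (m <= 64) return Sm100GroupGemmConfig::kM64;      // 128x16 tile, swap_ab, 1-SM schedule
  if (n >= 8192) return Sm100GroupGemmConfig::kN8192;  // 128x256 tile, 2-SM schedule, 2x1x1 cluster
  return Sm100GroupGemmConfig::kDefault;               // 128x256 tile, 1-SM schedule
}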
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu renamed to csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu

Lines changed: 17 additions & 13 deletions
@@ -21,10 +21,11 @@ struct sm90_fp8_config_default {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
   using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType,
@@ -38,10 +39,12 @@ struct sm90_fp8_config_M4 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
   using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, true>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule,
+                            true>;
 };

 template <typename InType, typename OutType,
@@ -55,10 +58,12 @@ struct sm90_fp8_config_M64 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
   using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, true>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule,
+                            true>;
 };

 template <typename InType, typename OutType,
@@ -72,10 +77,11 @@ struct sm90_fp8_config_K8192 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
   using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType,
@@ -89,10 +95,11 @@ struct sm90_fp8_config_N8192 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
   using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType>
@@ -112,9 +119,6 @@ void run_cutlass_moe_mm_sm90(
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
               "B tensors must be of type float8_e4m3fn.");

-  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
-
   using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
       InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
   using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<

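With ArchTag now threaded through as a template parameter, the SM90 and SM100 config structs differ only in the arch tag, schedules, and tile/cluster shapes they plug into the shared cutlass_3x_group_gemm wrapper. A hedged sketch of such an instantiation, not part of the diff (the alias name is hypothetical; the kernel and epilogue schedules are left as parameters; tile and cluster shapes are those of sm90_fp8_config_default above; assumes the includes from grouped_mm_c3x.cuh):

// Hypothetical alias showing one SM90 instantiation of the shared wrapper.
template <typename KernelSchedule, typename EpilogueSchedule>
using Sm90DefaultGroupGemm = cutlass_3x_group_gemm<
    cutlass::float_e4m3_t, cutlass::bfloat16_t, cutlass::arch::Sm90,
    vllm::c3x::ScaledEpilogueArray,
    cute::Shape<cute::_64, cute::_256, cute::_128>,  // TileShape
    cute::Shape<cute::_1, cute::_2, cute::_1>,       // ClusterShape
    KernelSchedule, EpilogueSchedule>;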
csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 1 addition & 1 deletion
@@ -190,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
       k);
-}
+}

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 34 additions & 11 deletions
@@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90(

 #endif

+#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
+void cutlass_moe_mm_sm100(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);
+#endif
+
 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
 void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
@@ -130,22 +140,25 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
   // and at least SM90 (Hopper)

 #if defined CUDA_VERSION
-  if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
-    return CUDA_VERSION >= 12000;
-  } else if (cuda_device_capability >= 100) {
+  if (cuda_device_capability >= 100) {
     return CUDA_VERSION >= 12080;
+  } else if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
   }
 #endif

   return false;
 }

 bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
-  // CUTLASS grouped FP8 kernels need at least CUDA 12.3
-  // and SM90 (Hopper)
+  // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
+  // or CUDA 12.8 and SM100 (Blackwell)

 #if defined CUDA_VERSION
-  if (cuda_device_capability == 90) {
+  if (cuda_device_capability >= 100) {
+    return CUDA_VERSION >= 12080;
+  }
+  if (cuda_device_capability >= 90) {
     return CUDA_VERSION >= 12030;
   }
 #endif
@@ -234,16 +247,26 @@ void cutlass_moe_mm(
     torch::Tensor const& b_strides, torch::Tensor const& c_strides,
     bool per_act_token, bool per_out_ch) {
   int32_t version_num = get_sm_version_num();
+#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
+  if (version_num >= 100) {
+    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                         expert_offsets, problem_sizes, a_strides, b_strides,
+                         c_strides, per_act_token, per_out_ch);
+    return;
+  }
+#endif
 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
-  cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
-                      expert_offsets, problem_sizes, a_strides, b_strides,
-                      c_strides, per_act_token, per_out_ch);
-  return;
+  if (version_num >= 90) {
+    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                        expert_offsets, problem_sizes, a_strides, b_strides,
+                        c_strides, per_act_token, per_out_ch);
+    return;
+  }
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
       false,
       "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
-      ". Required capability: 90");
+      ". Required capability: 90 or 100");
 }

 void get_cutlass_moe_mm_data(

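The support matrix encoded by cutlass_group_gemm_supported after this change can be restated compactly. This is an illustrative sketch, not vLLM API; the real function additionally guards on CUDA_VERSION being defined at compile time, and the helper name below is hypothetical.

// Sketch: grouped FP8 GEMM requires SM100 (Blackwell) + CUDA 12.8, or
// SM90 (Hopper) + CUDA 12.3; cuda_version uses CUDA_VERSION encoding
// (e.g. 12080 == 12.8).
inline bool group_gemm_supported_sketch(int capability, int cuda_version) {
  if (capability >= 100) return cuda_version >= 12080;
  if (capability >= 90) return cuda_version >= 12030;
  return false;
}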