Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit a7dcc62

Browse files
varun-sundar-rabindranathVarun Sundar Rabindranath
andauthored
[Kernel] Update Cutlass int8 kernel configs for SM80 (vllm-project#5275)
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent ad137cd commit a7dcc62

File tree

3 files changed

+123
-16
lines changed

3 files changed

+123
-16
lines changed

csrc/quantization/cutlass_w8a8/common.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "cutlass/cutlass.h"
4+
#include <climits>
45

56
/**
67
* Helper function for checking CUTLASS errors
@@ -10,3 +11,9 @@
1011
TORCH_CHECK(status == cutlass::Status::kSuccess, \
1112
cutlassGetStatusString(status)) \
1213
}
14+
15+
inline uint32_t next_pow_2(uint32_t const num) {
16+
if (num <= 1) return num;
17+
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
18+
}
19+

csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu

Lines changed: 116 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,120 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
250250
CUTLASS_CHECK(status);
251251
}
252252

253+
template <typename InType, typename OutType,
254+
template <typename, typename> typename Epilogue>
255+
struct sm80_config_default {
256+
// This config is used in 2 cases,
257+
// - M in (128, inf)
258+
// - M in (64, 128] and N >= 8192
259+
static_assert(std::is_same<InType, int8_t>());
260+
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
261+
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
262+
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
263+
using Cutlass2xGemm =
264+
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
265+
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
266+
};
267+
268+
template <typename InType, typename OutType,
269+
template <typename, typename> typename Epilogue>
270+
struct sm80_config_M64 {
271+
// This config is used in 2 cases,
272+
// - M in (32, 64]
273+
// - M in (64, 128] and N < 8192
274+
static_assert(std::is_same<InType, int8_t>());
275+
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
276+
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
277+
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
278+
using Cutlass2xGemm =
279+
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
280+
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
281+
};
282+
283+
template <typename InType, typename OutType,
284+
template <typename, typename> typename Epilogue>
285+
struct sm80_config_M32 {
286+
// M in (16, 32]
287+
static_assert(std::is_same<InType, int8_t>());
288+
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
289+
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
290+
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
291+
using Cutlass2xGemm =
292+
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
293+
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
294+
};
295+
296+
template <typename InType, typename OutType,
297+
template <typename, typename> typename Epilogue>
298+
struct sm80_config_M16 {
299+
// M in [1, 16]
300+
static_assert(std::is_same<InType, int8_t>());
301+
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
302+
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
303+
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
304+
using Cutlass2xGemm =
305+
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
306+
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
307+
};
308+
253309
} // namespace
254310

311+
template <typename InType, typename OutType,
312+
template <typename, typename> typename Epilogue,
313+
typename... EpilogueArgs>
314+
void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
315+
torch::Tensor const& b,
316+
EpilogueArgs&&... args) {
317+
static_assert(std::is_same<InType, int8_t>());
318+
TORCH_CHECK(a.dtype() == torch::kInt8);
319+
TORCH_CHECK(b.dtype() == torch::kInt8);
320+
321+
using Cutlass2xGemmDefault =
322+
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
323+
using Cutlass2xGemmM128BigN =
324+
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
325+
using Cutlass2xGemmM128SmallN =
326+
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
327+
using Cutlass2xGemmM64 =
328+
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
329+
using Cutlass2xGemmM32 =
330+
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
331+
using Cutlass2xGemmM16 =
332+
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
333+
334+
uint32_t const m = a.size(0);
335+
uint32_t const mp2 =
336+
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
337+
if (mp2 <= 16) {
338+
// M in [1, 16]
339+
return cutlass_gemm_caller<Cutlass2xGemmM16>(
340+
out, a, b, std::forward<EpilogueArgs>(args)...);
341+
} else if (mp2 <= 32) {
342+
// M in (16, 32]
343+
return cutlass_gemm_caller<Cutlass2xGemmM32>(
344+
out, a, b, std::forward<EpilogueArgs>(args)...);
345+
} else if (mp2 <= 64) {
346+
// M in (32, 64]
347+
return cutlass_gemm_caller<Cutlass2xGemmM64>(
348+
out, a, b, std::forward<EpilogueArgs>(args)...);
349+
} else if (mp2 <= 128) {
350+
// M in (64, 128]
351+
uint32_t const n = out.size(1);
352+
bool const small_n = n < 8192;
353+
if (small_n) {
354+
return cutlass_gemm_caller<Cutlass2xGemmM128SmallN>(
355+
out, a, b, std::forward<EpilogueArgs>(args)...);
356+
} else {
357+
return cutlass_gemm_caller<Cutlass2xGemmM128BigN>(
358+
out, a, b, std::forward<EpilogueArgs>(args)...);
359+
}
360+
} else {
361+
// M in (128, inf)
362+
return cutlass_gemm_caller<Cutlass2xGemmDefault>(
363+
out, a, b, std::forward<EpilogueArgs>(args)...);
364+
}
365+
}
366+
255367
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
256368
torch::Tensor const& b,
257369
torch::Tensor const& a_scales,
@@ -288,20 +400,13 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
288400
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
289401
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
290402

291-
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
292-
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
293-
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
294-
295403
if (out.dtype() == torch::kBFloat16) {
296-
return cutlass_gemm_caller<cutlass_2x_gemm<
297-
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
298-
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
299-
out, a, b, a_scales, b_scales);
404+
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
405+
ScaledEpilogue>(out, a, b, a_scales,
406+
b_scales);
300407
} else {
301408
TORCH_CHECK(out.dtype() == torch::kFloat16);
302-
return cutlass_gemm_caller<cutlass_2x_gemm<
303-
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
304-
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
409+
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
305410
out, a, b, a_scales, b_scales);
306411
}
307412
}

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ using namespace cute;
4444

4545
namespace {
4646

47-
uint32_t next_pow_2(uint32_t const num) {
48-
if (num <= 1) return num;
49-
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
50-
}
51-
5247
// A wrapper for the GEMM kernel that is used to guard against compilation on
5348
// architectures that will never use the kernel. The purpose of this is to
5449
// reduce the size of the compiled binary.

0 commit comments

Comments
 (0)