
Commit 3f3b6b2

[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (vllm-project#5715)

1 parent a7dcc62, commit 3f3b6b2

File tree: 5 files changed, +30 −13 lines

csrc/ops.h (2 additions, 0 deletions)

@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales);

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu (16 additions, 0 deletions)

@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales);
 #endif
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+  // CUTLASS FP8 kernels need at least
+  //   CUDA 12.0 on SM90 systems (Hopper)
+  //   CUDA 12.4 on SM89 systems (Lovelace)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 89) {
+    return CUDA_VERSION >= 12040;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales) {
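The new check keys off the compile-time CUDA_VERSION macro (major * 1000 + minor * 10, so 12000 is CUDA 12.0 and 12040 is CUDA 12.4) rather than whatever CUDA runtime happens to be installed. For illustration, here is a minimal Python sketch of the same decision table; the function name and example builds are mine, not part of the commit:

```python
def fp8_supported(device_capability: int, cuda_version: int) -> bool:
    """Mirror of the C++ check above, for illustration only.

    device_capability is major * 10 + minor (90 = SM90/Hopper, 89 = SM89/Lovelace);
    cuda_version follows the CUDA_VERSION macro encoding, major * 1000 + minor * 10.
    """
    if device_capability >= 90:       # Hopper: needs CUDA 12.0 or newer
        return cuda_version >= 12000
    if device_capability >= 89:       # Lovelace: needs CUDA 12.4 or newer
        return cuda_version >= 12040
    return False                      # older architectures: no CUTLASS FP8 kernels


assert fp8_supported(90, 12000) is True    # SM90 built against CUDA 12.0
assert fp8_supported(89, 12010) is False   # SM89 built against CUDA 12.1
assert fp8_supported(89, 12040) is True    # SM89 built against CUDA 12.4
```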

csrc/torch_bindings.cpp (6 additions, 0 deletions)

@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
           "                  Tensor b, Tensor a_scales,"
           "                  Tensor b_scales) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
+
+  // Check if cutlass scaled_mm is supported for CUDA devices of the given
+  // capability
+  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
+           &cutlass_scaled_mm_supports_fp8);
 #endif
 
   // Quantized GEMM for GPTQ.
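Once registered this way, the op is reachable from Python under the `_C` namespace, which is exactly what the `vllm/_custom_ops.py` wrapper below relies on. A hypothetical interactive check, assuming the compiled extension has already been loaded (for example by importing vllm):

```python
import torch
import vllm  # assumed here to load the compiled _C extension that registers the op

# 89 encodes compute capability 8.9 (SM89); the call returns True only when the
# extension was compiled against a new enough CUDA toolkit.
print(torch.ops._C.cutlass_scaled_mm_supports_fp8(89))
```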

vllm/_custom_ops.py (4 additions, 0 deletions)

@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
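Callers pass the device capability encoded as major * 10 + minor, the same encoding `fp8.py` computes below. A short usage sketch, with the import path assumed from this commit's layout:

```python
import torch
from vllm import _custom_ops as ops  # assumed import path for this wrapper

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 9) on an SM89 GPU
capability = major * 10 + minor                     # -> 89

if ops.cutlass_scaled_mm_supports_fp8(capability):
    print("CUTLASS FP8 scaled_mm is available for this device/build")
else:
    print("CUTLASS FP8 scaled_mm is not available; a fallback path is needed")
```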

vllm/model_executor/layers/quantization/fp8.py (2 additions, 13 deletions)

@@ -20,19 +20,8 @@
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
-
-    # CUTLASS FP8 kernels need at least
-    # CUDA 12.0 on SM90 systems (Hopper)
-    # CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
 class Fp8Config(QuantizationConfig):
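For context, the removed Python check is where the bug lived: it compared a major * 10 + minor encoding with a strict `>`, so the documented minimum versions themselves were rejected, and it read the runtime `torch.version.cuda` instead of the CUDA version the kernels were compiled against. A sketch of the removed logic, reconstructed from the diff above, showing the boundary cases it got wrong:

```python
def old_cutlass_fp8_supported(capability: int, cuda_version_str: str) -> bool:
    """Reconstruction of the removed check, for comparison only."""
    major, minor = cuda_version_str.split(".")
    version = int(major) * 10 + int(minor)  # "12.0" -> 120, "12.4" -> 124
    gpu_is_supported = False
    if capability >= 90:
        gpu_is_supported = version > 120    # strict '>': rejects CUDA 12.0 on Hopper
    elif capability >= 89:
        gpu_is_supported = version > 124    # strict '>': rejects CUDA 12.4 on Lovelace
    return gpu_is_supported


assert old_cutlass_fp8_supported(90, "12.0") is False  # but Hopper + CUDA 12.0 is supported
assert old_cutlass_fp8_supported(89, "12.4") is False  # but Lovelace + CUDA 12.4 is supported
```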
