Skip to content

Commit 5234dc7

Browse files
johnnynunez, Aidyn-A, and DrStone71
authored
[NVIDIA] Blackwell Family (vllm-project#24673)
Signed-off-by: Johnny <[email protected]>
Signed-off-by: johnnynunez <[email protected]>
Signed-off-by: Johnny <[email protected]>
Signed-off-by: Salvatore Cena <[email protected]>
Co-authored-by: Aidyn-A <[email protected]>
Co-authored-by: Salvatore Cena <[email protected]>
1 parent 3b7c20a commit 5234dc7

File tree

5 files changed

+66
-22
lines changed

5 files changed

+66
-22
lines changed

CMakeLists.txt

Lines changed: 53 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -86,6 +86,9 @@ find_package(Torch REQUIRED)
8686
# Supported NVIDIA architectures.
8787
# This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
8888
if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
89+
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
90+
set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0")
91+
elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
8992
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
9093
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
9194
else()
@@ -175,6 +178,15 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
175178
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
176179
endif()
177180

181+
#
182+
# Set compression mode for CUDA >=13.x.
183+
#
184+
if(VLLM_GPU_LANG STREQUAL "CUDA" AND
185+
DEFINED CMAKE_CUDA_COMPILER_VERSION AND
186+
CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
187+
list(APPEND VLLM_GPU_FLAGS "--compress-mode=size")
188+
endif()
189+
178190
#
179191
# Set CUDA include flags for CXX compiler.
180192
#
@@ -270,7 +282,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
270282
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
271283

272284
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
273-
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
285+
set(CUTLASS_REVISION "v4.2.1" CACHE STRING "CUTLASS revision to use")
274286

275287
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
276288
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -305,7 +317,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
305317
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
306318
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
307319
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
308-
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
309320
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
310321
"csrc/cutlass_extensions/common.cpp"
311322
"csrc/quantization/fp8/per_token_group_quant.cu")
@@ -440,7 +451,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
440451

441452
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
442453
# CUDA 12.8 or later
443-
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
454+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
455+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
456+
else()
457+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}")
458+
endif()
444459
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
445460
set(SRCS
446461
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
@@ -470,7 +485,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
470485

471486
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
472487
# require CUDA 12.8 or later
473-
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
488+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
489+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
490+
else()
491+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
492+
endif()
474493
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
475494
set(SRCS
476495
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
@@ -550,7 +569,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
550569

551570
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
552571
# CUDA 12.8 or later
553-
cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
572+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
573+
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
574+
else()
575+
cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}")
576+
endif()
554577
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
555578
set(SRCS
556579
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
@@ -569,7 +592,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
569592
endif()
570593

571594
# FP4 Archs and flags
572-
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
595+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
596+
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
597+
else()
598+
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}")
599+
endif()
573600
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
574601
set(SRCS
575602
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
@@ -591,7 +618,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
591618
endif()
592619

593620
# CUTLASS MLA Archs and flags
594-
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
621+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
622+
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
623+
else()
624+
cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
625+
endif()
595626
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
596627
set(SRCS
597628
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
@@ -635,7 +666,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
635666
endif()
636667
endif()
637668

638-
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
669+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
670+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f" "${CUDA_ARCHS}")
671+
else()
672+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
673+
endif()
639674
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
640675
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
641676
set_gencode_flags_for_srcs(
@@ -656,7 +691,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
656691
endif()
657692

658693
# moe_data.cu is used by all CUTLASS MoE kernels.
659-
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
694+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
695+
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
696+
else()
697+
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
698+
endif()
660699
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
661700
set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
662701
set_gencode_flags_for_srcs(
@@ -675,7 +714,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
675714
endif()
676715
endif()
677716

678-
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
717+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
718+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
719+
else()
720+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
721+
endif()
679722
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
680723
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
681724
set_gencode_flags_for_srcs(

cmake/utils.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,13 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
310310
list(REMOVE_DUPLICATES _PTX_ARCHS)
311311
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
312312

313-
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
314-
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
313+
# If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
314+
# remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS
315315
set(_CUDA_ARCHS)
316316
foreach(_arch ${_SRC_CUDA_ARCHS})
317-
if(_arch MATCHES "\\a$")
317+
if(_arch MATCHES "[af]$")
318318
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
319-
string(REPLACE "a" "" _base "${_arch}")
319+
string(REGEX REPLACE "[af]$" "" _base "${_arch}")
320320
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
321321
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
322322
list(APPEND _CUDA_ARCHS "${_arch}")

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
231231
} else {
232232
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
233233
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
234-
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
234+
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
235235
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
236236
out, a, b, a_scales, b_scales);
237237
}
@@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
245245
} else {
246246
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
247247
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
248-
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
248+
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
249249
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
250250
out, a, b, a_scales, b_scales);
251251
}
@@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
259259
} else {
260260
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
261261
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
262-
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
262+
Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm,
263263
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
264264
out, a, b, a_scales, b_scales);
265265
}
@@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
271271
// TMA epilogue isn't compatible with Swap A/B
272272
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
273273
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
274-
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
274+
Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
275275
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
276276
out, a, b, a_scales, b_scales);
277277
}
278278
}
279279

280-
} // namespace vllm
280+
} // namespace vllm

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,4 @@ void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
133133
}
134134
}
135135

136-
} // namespace vllm
136+
} // namespace vllm

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
6767
std::optional<torch::Tensor> const& bias);
6868
#endif
6969

70-
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
71-
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
70+
#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \
71+
defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
72+
defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
7273
void get_cutlass_moe_mm_data_caller(
7374
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
7475
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,

0 commit comments

Comments (0)