Skip to content

Commit 511e81e

Browse files
authored
[BUILD] use sm_100f when compiling flashmla to fix support on sm103 (vllm-project#30705)
Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
1 parent a182be4 commit 511e81e

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

cmake/external_projects/flashmla.cmake

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,21 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
3535
# sm90a
3636

3737
set(SUPPORT_ARCHS)
38-
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
39-
list(APPEND SUPPORT_ARCHS 9.0a)
38+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3)
39+
list(APPEND SUPPORT_ARCHS "9.0a")
4040
endif()
41-
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
42-
list(APPEND SUPPORT_ARCHS 10.0a)
41+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.9)
42+
# CUDA 12.9 has introduced "Family-Specific Architecture Features"
43+
# this supports all compute_10x family
44+
list(APPEND SUPPORT_ARCHS "10.0f")
45+
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
46+
list(APPEND SUPPORT_ARCHS "10.0a")
4347
endif()
4448

4549

4650
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
4751
if(FLASH_MLA_ARCHS)
52+
message(STATUS "FlashMLA CUDA architectures: ${FLASH_MLA_ARCHS}")
4853
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
4954
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
5055

@@ -126,7 +131,8 @@ if(FLASH_MLA_ARCHS)
126131
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
127132
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
128133
else()
129-
# Create empty targets for setup.py when not targeting sm90a systems
134+
message(STATUS "FlashMLA will not compile: unsupported CUDA architecture ${CUDA_ARCHS}")
135+
# Create empty targets for setup.py on unsupported systems
130136
add_custom_target(_flashmla_C)
131137
add_custom_target(_flashmla_extension_C)
132138
endif()

0 commit comments

Comments
 (0)