@@ -35,16 +35,21 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
3535# sm90a
3636
# Architectures FlashMLA can be built for, gated on the CUDA compiler version.
# The variable name is handed to if() unexpanded so that an unset or empty
# CMAKE_CUDA_COMPILER_VERSION makes the checks false instead of producing a
# malformed if() expression.
set(SUPPORT_ARCHS)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
  # Entries must be bare arch tokens: any whitespace inside the quotes would
  # become part of the list element.
  list(APPEND SUPPORT_ARCHS "9.0a")
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9)
  # CUDA 12.9 introduced "Family-Specific Architecture Features"; the 10.0f
  # target supports the whole compute_10x family.
  list(APPEND SUPPORT_ARCHS "10.0f")
elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
  # Before 12.9, fall back to the 10.0a target.
  list(APPEND SUPPORT_ARCHS "10.0a")
endif()
4448
4549
# Compute FLASH_MLA_ARCHS from the supported and requested architecture lists.
# The helper is defined elsewhere; presumably it performs a loose intersection
# of the two lists -- confirm against its definition.
# Both lists are passed quoted with no padding inside the quotes so that no
# stray whitespace ends up in the last element of either list.
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
4751if (FLASH_MLA_ARCHS)
52+ message (STATUS "FlashMLA CUDA architectures: ${FLASH_MLA_ARCHS} " )
4853 set (VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS} )
4954 list (APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math" )
5055
@@ -126,7 +131,8 @@ if(FLASH_MLA_ARCHS)
126131 $<$<COMPILE_LANGUAGE :CUDA >:-UPy_LIMITED_API >
127132 $<$<COMPILE_LANGUAGE :CXX >:-UPy_LIMITED_API >)
128133else ()
129- # Create empty targets for setup.py when not targeting sm90a systems
134+ message (STATUS "FlashMLA will not compile: unsupported CUDA architecture ${CUDA_ARCHS} " )
135+ # Create empty targets for setup.py on unsupported systems
130136 add_custom_target (_flashmla_C )
131137 add_custom_target (_flashmla_extension_C )
132138endif ()
0 commit comments