|
| 1 | +# ============================================================================= |
| 2 | +# Configure cuBLASMp dependencies and linking for ABACUS |
| 3 | +# ============================================================================= |
| 4 | + |
| 5 | +include_guard(GLOBAL) |
| 6 | + |
| 7 | +function(abacus_setup_cublasmp target_name) |
| 8 | + add_compile_definitions(__CUBLASMP) |
| 9 | + |
| 10 | + # 1. Search for cuBLASMp library and header files |
| 11 | + # libcublasmp.so |
| 12 | + find_library(CUBLASMP_LIBRARY NAMES cublasmp |
| 13 | + HINTS ${CUBLASMP_PATH} ${NVHPC_ROOT_DIR} |
| 14 | + PATH_SUFFIXES lib lib64 math_libs/lib math_libs/lib64) |
| 15 | + |
| 16 | + # cublasmp.h |
| 17 | + find_path(CUBLASMP_INCLUDE_DIR NAMES cublasmp.h |
| 18 | + HINTS ${CUBLASMP_PATH} ${NVHPC_ROOT_DIR} |
| 19 | + PATH_SUFFIXES include math_libs/include) |
| 20 | + |
| 21 | + if(NOT CUBLASMP_LIBRARY OR NOT CUBLASMP_INCLUDE_DIR) |
| 22 | + message(FATAL_ERROR |
| 23 | + "cuBLASMp not found. Please ensure CUBLASMP_PATH is set correctly." |
| 24 | + ) |
| 25 | + endif() |
| 26 | + |
| 27 | + message(STATUS "Found cuBLASMp: ${CUBLASMP_LIBRARY}") |
| 28 | + |
| 29 | + # 2. Version validation by parsing header macros |
| 30 | + set(CUBLASMP_VERSION_STR "") |
| 31 | + set(CUBLASMP_VERSION_HEADER "${CUBLASMP_INCLUDE_DIR}/cublasmp.h") |
| 32 | + |
| 33 | + if(EXISTS "${CUBLASMP_VERSION_HEADER}") |
| 34 | + # Extract version lines using regular expressions from cublasmp.h |
| 35 | + file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_MAJOR_LINE |
| 36 | + REGEX "^#define[ \t]+CUBLASMP_VER_MAJOR[ \t]+[0-9]+") |
| 37 | + file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_MINOR_LINE |
| 38 | + REGEX "^#define[ \t]+CUBLASMP_VER_MINOR[ \t]+[0-9]+") |
| 39 | + file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_PATCH_LINE |
| 40 | + REGEX "^#define[ \t]+CUBLASMP_VER_PATCH[ \t]+[0-9]+") |
| 41 | + |
| 42 | + # Extract numeric values from the matched strings |
| 43 | + string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_MAJOR "${CUBLASMP_MAJOR_LINE}") |
| 44 | + string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_MINOR "${CUBLASMP_MINOR_LINE}") |
| 45 | + string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_PATCH "${CUBLASMP_PATCH_LINE}") |
| 46 | + |
| 47 | + if(NOT CUBLASMP_VER_MAJOR STREQUAL "" |
| 48 | + AND NOT CUBLASMP_VER_MINOR STREQUAL "" |
| 49 | + AND NOT CUBLASMP_VER_PATCH STREQUAL "") |
| 50 | + set(CUBLASMP_VERSION_STR |
| 51 | + "${CUBLASMP_VER_MAJOR}.${CUBLASMP_VER_MINOR}.${CUBLASMP_VER_PATCH}") |
| 52 | + endif() |
| 53 | + endif() |
| 54 | + |
| 55 | + message(STATUS "Detected cuBLASMp version: ${CUBLASMP_VERSION_STR}") |
| 56 | + |
| 57 | + # 3. Version constraint: ABACUS requires cuBLASMp >= 0.8.0 |
| 58 | + if(CUBLASMP_VERSION_STR AND CUBLASMP_VERSION_STR VERSION_LESS "0.8.0") |
| 59 | + message(FATAL_ERROR |
| 60 | + "cuBLASMp version ${CUBLASMP_VERSION_STR} is too old. " |
| 61 | + "ABACUS requires cuBLASMp >= 0.8.0 for NCCL Symmetric Memory support." |
| 62 | + ) |
| 63 | + elseif(NOT CUBLASMP_VERSION_STR) |
| 64 | + message(WARNING "Could not detect cuBLASMp version. Proceeding cautiously.") |
| 65 | + endif() |
| 66 | + |
| 67 | + # 4. Create cublasMp::cublasMp imported target |
| 68 | + if(NOT TARGET cublasMp::cublasMp) |
| 69 | + add_library(cublasMp::cublasMp IMPORTED INTERFACE) |
| 70 | + set_target_properties(cublasMp::cublasMp PROPERTIES |
| 71 | + INTERFACE_LINK_LIBRARIES "${CUBLASMP_LIBRARY};NCCL::NCCL" |
| 72 | + INTERFACE_INCLUDE_DIRECTORIES "${CUBLASMP_INCLUDE_DIR}") |
| 73 | + endif() |
| 74 | + |
| 75 | + # 5. Link the library to the target |
| 76 | + target_link_libraries(${target_name} cublasMp::cublasMp) |
| 77 | + |
| 78 | +endfunction() |
0 commit comments