Skip to content

Commit 8efe9f5

Browse files
authored
Feature: Multi-GPU support for RT-TDDFT (#7026)
1 parent 81c5067 commit 8efe9f5

26 files changed

Lines changed: 1668 additions & 763 deletions

CMakeLists.txt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,28 @@ if(USE_CUDA)
449449
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=${OpenMP_CXX_FLAGS}" CACHE STRING "CUDA flags" FORCE)
450450
endif()
451451
if (ENABLE_CUSOLVERMP)
452-
# Keep cuSolverMp discovery/linking logic in a dedicated module.
452+
# Keep cuSOLVERMp discovery/linking logic in a dedicated module.
453453
include(cmake/SetupCuSolverMp.cmake)
454454
abacus_setup_cusolvermp(${ABACUS_BIN_NAME})
455455
endif()
456+
if (ENABLE_CUBLASMP)
457+
# Enforcement 1: cuBLASMp requires cuSOLVERMp to be enabled
# (message() joins its arguments with no separator, so each fragment that is followed by another ends with a space)
458+
if (NOT ENABLE_CUSOLVERMP)
459+
message(FATAL_ERROR
460+
"ENABLE_CUBLASMP is set to ON, but ENABLE_CUSOLVERMP is OFF. "
461+
"In ABACUS, cuBLASMp support requires cuSOLVERMp to be enabled simultaneously. "
462+
"Please set -DENABLE_CUSOLVERMP=ON.")
463+
endif()
464+
# Enforcement 2: cuBLASMp 0.8.0+ is incompatible with CAL backend
465+
# Note: _use_cal is set by abacus_setup_cusolvermp and raised to this scope via set(... PARENT_SCOPE)
466+
if (_use_cal)
467+
message(FATAL_ERROR
468+
"cuBLASMp 0.8.0+ requires NCCL Symmetric Memory, but cuSOLVERMp is using CAL backend. "
469+
"Please upgrade cuSOLVERMp to >= 0.7.0 to use NCCL for both.")
470+
endif()
471+
include(cmake/SetupCuBlasMp.cmake)
472+
abacus_setup_cublasmp(${ABACUS_BIN_NAME})
473+
endif()
456474
endif()
457475
endif()
458476

cmake/SetupCuBlasMp.cmake

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# =============================================================================
2+
# Configure cuBLASMp dependencies and linking for ABACUS
3+
# =============================================================================
4+
5+
include_guard(GLOBAL)
6+
7+
# abacus_setup_cublasmp(<target_name>)
#
# Locates cuBLASMp (library + header), validates its version, wraps it in the
# imported target cublasMp::cublasMp and links that target into <target_name>.
#
# Arguments:
#   target_name - name of an existing target that should link against cuBLASMp.
#
# Inputs read:
#   CUBLASMP_PATH, NVHPC_ROOT_DIR - search hints for the library and header.
#
# Effects:
#   - Adds the compile definition __CUBLASMP via add_compile_definitions().
#   - Fills CUBLASMP_LIBRARY / CUBLASMP_INCLUDE_DIR through find_library()/find_path().
#   - Stops configuration with FATAL_ERROR when cuBLASMp is not found or when the
#     detected version is older than 0.8.0; only warns when the version cannot be
#     parsed from the header.
#
# NOTE(review): the imported target references NCCL::NCCL, which is not created in
# this function - presumably it is provided by the cuSOLVERMp/NCCL setup; confirm.
function(abacus_setup_cublasmp target_name)
8+
add_compile_definitions(__CUBLASMP)
9+
10+
# 1. Search for the cuBLASMp library and header files
11+
# Shared library: libcublasmp.so (searched under the hint roots and their lib sub-directories)
12+
find_library(CUBLASMP_LIBRARY NAMES cublasmp
13+
HINTS ${CUBLASMP_PATH} ${NVHPC_ROOT_DIR}
14+
PATH_SUFFIXES lib lib64 math_libs/lib math_libs/lib64)
15+
16+
# Public header: cublasmp.h (searched under the hint roots and their include sub-directories)
17+
find_path(CUBLASMP_INCLUDE_DIR NAMES cublasmp.h
18+
HINTS ${CUBLASMP_PATH} ${NVHPC_ROOT_DIR}
19+
PATH_SUFFIXES include math_libs/include)
20+
21+
if(NOT CUBLASMP_LIBRARY OR NOT CUBLASMP_INCLUDE_DIR)
22+
message(FATAL_ERROR
23+
"cuBLASMp not found. Please ensure CUBLASMP_PATH is set correctly."
24+
)
25+
endif()
26+
27+
message(STATUS "Found cuBLASMp: ${CUBLASMP_LIBRARY}")
28+
29+
# 2. Version validation by parsing the version macros in the header
30+
set(CUBLASMP_VERSION_STR "")
31+
set(CUBLASMP_VERSION_HEADER "${CUBLASMP_INCLUDE_DIR}/cublasmp.h")
32+
33+
if(EXISTS "${CUBLASMP_VERSION_HEADER}")
34+
# Pull the three "#define CUBLASMP_VER_* <number>" lines out of cublasmp.h
35+
file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_MAJOR_LINE
36+
REGEX "^#define[ \t]+CUBLASMP_VER_MAJOR[ \t]+[0-9]+")
37+
file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_MINOR_LINE
38+
REGEX "^#define[ \t]+CUBLASMP_VER_MINOR[ \t]+[0-9]+")
39+
file(STRINGS "${CUBLASMP_VERSION_HEADER}" CUBLASMP_PATCH_LINE
40+
REGEX "^#define[ \t]+CUBLASMP_VER_PATCH[ \t]+[0-9]+")
41+
42+
# Keep the first run of digits on each matched line; the macro names contain no digits, so that run is the value
43+
string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_MAJOR "${CUBLASMP_MAJOR_LINE}")
44+
string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_MINOR "${CUBLASMP_MINOR_LINE}")
45+
string(REGEX MATCH "([0-9]+)" CUBLASMP_VER_PATCH "${CUBLASMP_PATCH_LINE}")
46+
47+
# Only build the version string when all three components were found
if(NOT CUBLASMP_VER_MAJOR STREQUAL ""
48+
AND NOT CUBLASMP_VER_MINOR STREQUAL ""
49+
AND NOT CUBLASMP_VER_PATCH STREQUAL "")
50+
set(CUBLASMP_VERSION_STR
51+
"${CUBLASMP_VER_MAJOR}.${CUBLASMP_VER_MINOR}.${CUBLASMP_VER_PATCH}")
52+
endif()
53+
endif()
54+
55+
message(STATUS "Detected cuBLASMp version: ${CUBLASMP_VERSION_STR}")
56+
57+
# 3. Version constraint: ABACUS requires cuBLASMp >= 0.8.0
#    An empty (undetected) version is not fatal; it only triggers the warning below.
58+
if(CUBLASMP_VERSION_STR AND CUBLASMP_VERSION_STR VERSION_LESS "0.8.0")
59+
message(FATAL_ERROR
60+
"cuBLASMp version ${CUBLASMP_VERSION_STR} is too old. "
61+
"ABACUS requires cuBLASMp >= 0.8.0 for NCCL Symmetric Memory support."
62+
)
63+
elseif(NOT CUBLASMP_VERSION_STR)
64+
message(WARNING "Could not detect cuBLASMp version. Proceeding cautiously.")
65+
endif()
66+
67+
# 4. Create the cublasMp::cublasMp imported target (skipped if it already exists)
68+
if(NOT TARGET cublasMp::cublasMp)
69+
add_library(cublasMp::cublasMp IMPORTED INTERFACE)
70+
set_target_properties(cublasMp::cublasMp PROPERTIES
71+
INTERFACE_LINK_LIBRARIES "${CUBLASMP_LIBRARY};NCCL::NCCL"
72+
INTERFACE_INCLUDE_DIRECTORIES "${CUBLASMP_INCLUDE_DIR}")
73+
endif()
74+
75+
# 5. Link the imported target into the requested target
76+
target_link_libraries(${target_name} cublasMp::cublasMp)
77+
78+
endfunction()

cmake/SetupCuSolverMp.cmake

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# =============================================================================
2-
# Configure cuSolverMp dependencies and linking for ABACUS
2+
# Configure cuSOLVERMp dependencies and linking for ABACUS
33
# =============================================================================
44

55
include_guard(GLOBAL)
66

77
function(abacus_setup_cusolvermp target_name)
88
add_compile_definitions(__CUSOLVERMP)
99

10-
# Find cuSolverMp first, then decide communicator backend.
10+
# Find cuSOLVERMp first, then decide communicator backend.
1111
find_library(CUSOLVERMP_LIBRARY NAMES cusolverMp
1212
HINTS ${CAL_CUSOLVERMP_PATH} ${NVHPC_ROOT_DIR}
1313
PATH_SUFFIXES lib lib64 math_libs/lib math_libs/lib64)
@@ -18,11 +18,11 @@ function(abacus_setup_cusolvermp target_name)
1818

1919
if(NOT CUSOLVERMP_LIBRARY OR NOT CUSOLVERMP_INCLUDE_DIR)
2020
message(FATAL_ERROR
21-
"cusolverMp not found. Set CUSOLVERMP_PATH or NVHPC_ROOT_DIR."
21+
"cuSOLVERMp not found. Set CUSOLVERMP_PATH or NVHPC_ROOT_DIR."
2222
)
2323
endif()
2424

25-
message(STATUS "Found cusolverMp: ${CUSOLVERMP_LIBRARY}")
25+
message(STATUS "Found cuSOLVERMp: ${CUSOLVERMP_LIBRARY}")
2626

2727
set(CUSOLVERMP_VERSION_STR "")
2828
set(CUSOLVERMP_VERSION_HEADER "${CUSOLVERMP_INCLUDE_DIR}/cusolverMp.h")
@@ -47,27 +47,30 @@ function(abacus_setup_cusolvermp target_name)
4747
# Check minimum version requirement (>= 0.4.0)
4848
if(CUSOLVERMP_VERSION_STR AND CUSOLVERMP_VERSION_STR VERSION_LESS "0.4.0")
4949
message(FATAL_ERROR
50-
"cuSolverMp version ${CUSOLVERMP_VERSION_STR} is too old. "
51-
"ABACUS requires cuSolverMp >= 0.4.0 (NVIDIA HPC SDK >= 23.5). "
50+
"cuSOLVERMp version ${CUSOLVERMP_VERSION_STR} is too old. "
51+
"ABACUS requires cuSOLVERMp >= 0.4.0 (NVIDIA HPC SDK >= 23.5). "
5252
"Please upgrade your NVIDIA HPC SDK installation."
5353
)
5454
endif()
5555

56-
# Auto-select communicator backend by cuSolverMp version.
57-
# cuSolverMp < 0.7.0 -> CAL, otherwise -> NCCL.
56+
# Auto-select communicator backend by cuSOLVERMp version.
57+
# cuSOLVERMp < 0.7.0 -> CAL, otherwise -> NCCL.
5858
set(_use_cal OFF)
5959
if(CUSOLVERMP_VERSION_STR AND CUSOLVERMP_VERSION_STR VERSION_LESS "0.7.0")
6060
set(_use_cal ON)
6161
message(STATUS
62-
"Detected cuSolverMp ${CUSOLVERMP_VERSION_STR} (< 0.7.0). Using CAL backend.")
62+
"Detected cuSOLVERMp ${CUSOLVERMP_VERSION_STR} (< 0.7.0). Using CAL backend.")
6363
elseif(CUSOLVERMP_VERSION_STR)
6464
message(STATUS
65-
"Detected cuSolverMp ${CUSOLVERMP_VERSION_STR} (>= 0.7.0). Using NCCL backend.")
65+
"Detected cuSOLVERMp ${CUSOLVERMP_VERSION_STR} (>= 0.7.0). Using NCCL backend.")
6666
elseif(NOT CUSOLVERMP_VERSION_STR)
6767
message(WARNING
68-
"Unable to detect cuSolverMp version from header. Using NCCL backend by default.")
68+
"Unable to detect cuSOLVERMp version from header. Using NCCL backend by default.")
6969
endif()
7070

71+
# Raise the variable to the caller's scope
72+
set(_use_cal ${_use_cal} PARENT_SCOPE)
73+
7174
# Backend selection:
7275
# - _use_cal=ON -> cal communicator backend
7376
# - _use_cal=OFF -> NCCL communicator backend

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ ESolver_KS_LCAO_TDDFT<TR, Device>::ESolver_KS_LCAO_TDDFT()
3030
if (ct_device_type == ct::DeviceType::GpuDevice)
3131
{
3232
use_tensor = true;
33-
use_lapack = true;
33+
if (PARAM.inp.ks_solver != "cusolvermp")
34+
{
35+
use_lapack = true;
36+
}
3437
}
3538
}
3639

@@ -235,21 +238,22 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::hamilt2rho_single(UnitCell& ucell,
235238
{
236239
if (istep >= TD_info::estep_shift + 1)
237240
{
238-
module_rt::Evolve_elec<Device>::solve_psi(istep,
239-
PARAM.inp.nbands,
240-
PARAM.globalv.nlocal,
241-
this->kv.get_nks(),
242-
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
243-
this->pv,
244-
this->psi,
245-
this->psi_laststep,
246-
this->Hk_laststep,
247-
this->Sk_laststep,
248-
this->pelec->ekb,
249-
GlobalV::ofs_running,
250-
PARAM.inp.propagator,
251-
use_tensor,
252-
use_lapack);
241+
module_rt::Evolve_elec<Device>::solve_psi(
242+
istep,
243+
PARAM.inp.nbands,
244+
PARAM.globalv.nlocal,
245+
this->kv.get_nks(),
246+
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt),
247+
this->pv,
248+
this->psi,
249+
this->psi_laststep,
250+
this->Hk_laststep,
251+
this->Sk_laststep,
252+
this->pelec->ekb,
253+
GlobalV::ofs_running,
254+
PARAM.inp.propagator,
255+
use_tensor,
256+
use_lapack);
253257
}
254258
this->weight_dm_rho(ucell);
255259
}
@@ -346,11 +350,18 @@ void ESolver_KS_LCAO_TDDFT<TR, Device>::iter_finish(UnitCell& ucell,
346350
{
347351
if (use_tensor && use_lapack)
348352
{
349-
elecstate::cal_edm_tddft_tensor_lapack<Device>(this->pv, this->dmat, this->kv, static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
353+
elecstate::cal_edm_tddft_tensor_lapack<Device>(
354+
this->pv,
355+
this->dmat,
356+
this->kv,
357+
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
350358
}
351359
else
352360
{
353-
elecstate::cal_edm_tddft(this->pv, this->dmat, this->kv, static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
361+
elecstate::cal_edm_tddft(this->pv,
362+
this->dmat,
363+
this->kv,
364+
static_cast<hamilt::Hamilt<std::complex<double>>*>(this->p_hamilt));
354365
}
355366
}
356367
}

source/source_esolver/esolver_ks_lcao_tddft.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
#include "esolver_ks.h"
44
#include "esolver_ks_lcao.h"
55
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
6-
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
6+
#include "source_lcao/module_rt/boundary_fix.h"
7+
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
8+
#include "source_lcao/module_rt/kernels/cublasmp_context.h"
79
#include "source_lcao/module_rt/td_info.h"
810
#include "source_lcao/module_rt/velocity_op.h"
9-
#include "source_lcao/module_rt/boundary_fix.h"
1011

1112
namespace ModuleESolver
1213
{
@@ -51,6 +52,7 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>
5152
//! Control heterogeneous computing of the TDDFT solver
5253
bool use_tensor = false;
5354
bool use_lapack = false;
55+
CublasMpResources cublas_res;
5456

5557
// Control the device type for Hk_laststep and Sk_laststep
5658
// Set to CPU temporarily, should wait for further GPU development

source/source_lcao/module_rt/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ if(ENABLE_LCAO)
2222
list(APPEND objects
2323
kernels/cuda/snap_psibeta_kernel.cu
2424
kernels/cuda/snap_psibeta_gpu.cu
25+
kernels/cuda/norm_psi_kernel.cu
26+
kernels/cuda/band_energy_kernel.cu
2527
)
2628
endif()
2729

0 commit comments

Comments
 (0)