Skip to content

Commit 4790ed2

Browse files
haxushudarelbeida
andauthored
Add: use cuSolver for LCAO in GPU (#799)
Co-authored-by: Xia, Yu <[email protected]>
1 parent 54091f2 commit 4790ed2

30 files changed

+1944
-53
lines changed

CMakeLists.txt

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ project(ABACUS
1313

1414
option(ENABLE_DEEPKS "Enable DeePKS functionality" OFF)
1515
option(ENABLE_LIBXC "Enable LibXC functionality" OFF)
16-
option(USE_CUDA "Enable support to CUDA." OFF)
16+
option(USE_CUDA "Enable support to CUDA for PW." OFF)
17+
option(USE_CUSOLVER_LCAO "Enable support to CUSOLVER for LCAO." OFF)
1718
option(USE_ROCM "Enable support to ROCm." OFF)
1819
option(USE_OPENMP " Enable OpenMP in abacus." ON)
1920
option(ENABLE_ASAN "Enable AddressSanitizer" OFF)
@@ -58,21 +59,32 @@ set(CMAKE_CXX_STANDARD 11)
5859
include(CheckLanguage)
5960
check_language(CUDA)
6061
if(CMAKE_CUDA_COMPILER)
61-
if(NOT DEFINED USE_CUDA)
62-
message("CUDA components detected. \nWill build the CUDA version of ABACUS.")
63-
set(USE_CUDA ON)
62+
if(NOT DEFINED USE_CUDA OR NOT DEFINED USE_CUSOLVER_LCAO)
63+
if (NOT DEFINED USE_CUDA AND NOT DEFINED USE_CUSOLVER_LCAO)
64+
message("CUDA components detected. \nWill build the CUDA for PW version of ABACUS by default.")
65+
set(USE_CUDA ON)
66+
set(USE_CUSOLVER_LCAO OFF)
67+
elseif (NOT DEFINED USE_CUDA)
68+
set(USE_CUDA OFF)
69+
else()
70+
set(USE_CUSOLVER_LCAO OFF)
71+
endif()
6472
else()
65-
if(NOT USE_CUDA)
66-
message(WARNING "CUDA components detected, but USE_CUDA set to OFF. \nNOT building CUDA version of ABACUS.")
73+
if(NOT USE_CUDA AND NOT USE_CUSOLVER_LCAO)
74+
message(WARNING "CUDA components detected, but both USE_CUDA and USE_CUSOLVER_LCAO set to OFF. \nNOT building CUDA version of ABACUS.")
75+
elseif (USE_CUDA AND USE_CUSOLVER_LCAO)
76+
message(FATAL_ERROR "USE_CUDA and USE_CUSOLVER_LCAO set, but now they not allowed to coexist.")
6777
endif()
6878
endif()
6979
else() # CUDA not found
70-
if (USE_CUDA)
71-
message(FATAL_ERROR "USE_CUDA set but no CUDA components found.")
80+
if (USE_CUDA OR USE_CUSOLVER_LCAO)
81+
message(FATAL_ERROR "USE_CUDA or USE_CUSOLVER_LCAO set but no CUDA components found.")
7282
set(USE_CUDA OFF)
83+
set(USE_CUSOLVER_LCAO OFF)
7384
endif()
7485
endif()
75-
if(USE_CUDA)
86+
87+
if(USE_CUDA OR USE_CUSOLVER_LCAO)
7688
set(CMAKE_CXX_STANDARD 14)
7789
set(CMAKE_CXX_EXTENSIONS ON)
7890
set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -92,11 +104,18 @@ if(USE_CUDA)
92104
60 # P100
93105
70 # V100
94106
75 # T4
107+
80 # A100
95108
)
96109
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
97-
add_compile_definitions(__CUDA)
110+
if (USE_CUDA)
111+
add_compile_definitions(__CUDA)
112+
endif()
113+
if (USE_CUSOLVER_LCAO)
114+
add_compile_definitions(__CUSOLVER_LCAO)
115+
endif()
98116
endif()
99117

118+
100119
# Warning: CMake add support to HIP in version 3.21. This is rather a new version.
101120
# Use cmake with AMD-ROCm: https://rocmdocs.amd.com/en/latest/Installation_Guide/Using-CMake-with-AMD-ROCm.html
102121
if(USE_ROCM)

docs/features.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ equation. For PW basis, there are CG and Blocked Davidson methods for solving th
7474
equation for each basis.
7575

7676
- PW: ks_solver = ‘cg’ or ‘dav’
77-
- LCAO: ks_solver = ‘hpseps’ , ‘genelpa’ or ‘lapack’
77+
- LCAO: ks_solver = ‘hpseps’, ‘genelpa’, ‘scalapack_gvx’ or ‘cusolver’
7878
- LCAO_in_PW: ks_solver = ‘lapack’
7979

8080
If you set ks_solver=‘hpseps’ for basis_type=‘pw’, the program will be stopped with an error

docs/input-main.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ calculations.
441441
- genelpa: This method should be used if you choose localized orbitals.
442442
- hpseps: old method, still used.
443443
- lapack: lapack can be used for localized orbitals, but is only used for single processor.
444+
- cusolver: this method requires ABACUS to be built with the cuSOLVER component for LCAO, and at least one GPU must be available.
444445

445446
If you set ks_solver=`hpseps` for basis_type=`pw`, the program will be stopped with an error message:
446447

docs/install.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ You can also choose to build with which components.
7575
```bash
7676
cmake -B build -DUSE_LIBXC=1 -DUSE_CUDA=1
7777
```
78+
```bash
79+
cmake -B build -DUSE_CUSOLVER_LCAO=1
80+
```
7881

7982
If Libxc is not installed in standard path (i.e. installed with a custom prefix path), you may add the installation prefix of `FindLibxc.cmake` to `CMAKE_MODULE_PATH` environment variable, or set `Libxc_DIR` to the directory containing the file.
8083

source/input.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2485,6 +2485,12 @@ void Input::Check(void)
24852485
{
24862486
ModuleBase::WARNING_QUIT("Input", "not ready for linear_scaling method in lcao .");
24872487
}
2488+
else if (ks_solver == "cusolver")
2489+
{
2490+
#ifndef __MPI
2491+
ModuleBase::WARNING_QUIT("Input","Cusolver can not be used for series version.");
2492+
#endif
2493+
}
24882494
else
24892495
{
24902496
ModuleBase::WARNING_QUIT("Input", "please check the ks_solver parameter!");
@@ -2666,7 +2672,7 @@ void Input::Check(void)
26662672
if (!(calculation == "nscf"))
26672673
ModuleBase::WARNING_QUIT("Input", "calculate berry phase, please set calculation = nscf");
26682674
}
2669-
else if (basis_type == "lcao" && (ks_solver == "genelpa" || ks_solver == "scalapack_gvx"))
2675+
else if (basis_type == "lcao" && ks_solver == "genelpa" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver")
26702676
{
26712677
if (!(calculation == "nscf"))
26722678
ModuleBase::WARNING_QUIT("Input", "calculate berry phase, please set calculation = nscf");

source/module_base/global_function.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,16 @@ double ddot_real(
326326
const std::complex<double>* psi_R,
327327
const bool reduce = true) ;
328328

329+
//==========================================================
330+
// GLOBAL FUNCTION :
331+
// NAME : IS_COLUMN_MAJOR_KS_SOLVER
332+
// check whether the current ks_solver requires column-major matrix storage
333+
//==========================================================
334+
static inline bool IS_COLUMN_MAJOR_KS_SOLVER() // returns true when GlobalV::KS_SOLVER is one of the solvers compared below
335+
{
336+
return GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx" || GlobalV::KS_SOLVER=="cusolver"; // callers use the column-major index (row + col*nrow) when this holds
337+
}
338+
329339
}//namespace GlobalFunc
330340
}//namespace ModuleBase
331341

source/module_deepks/LCAO_deepks_vdelta.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ void LCAO_Deepks::add_v_delta(const UnitCell_pseudo &ucell,
113113

114114
int iic;
115115

116-
if(GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx") // save the matrix as column major format
116+
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER())
117117
{
118118
iic=iw1_local+iw2_local*nrow;
119119
}
@@ -142,7 +142,7 @@ void LCAO_Deepks::check_v_delta(const int nrow, const int ncol)
142142
for (int icol=0;icol<ncol;icol++)
143143
{
144144
int iic;
145-
if(GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx") // save the matrix as column major format
145+
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER())
146146
{
147147
iic=irow+icol*nrow;
148148
}
@@ -429,7 +429,7 @@ void LCAO_Deepks::cal_e_delta_band_k(const std::vector<ModuleBase::ComplexMatrix
429429
if (mu >= 0 && nu >= 0)
430430
{
431431
int iic;
432-
if(GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx") // save the matrix as column major format
432+
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER())
433433
{
434434
iic=mu+nu*nrow;
435435
}
@@ -452,4 +452,4 @@ void LCAO_Deepks::cal_e_delta_band_k(const std::vector<ModuleBase::ComplexMatrix
452452
return;
453453
}
454454

455-
#endif
455+
#endif

source/module_deepks/test/nnr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void test_deepks::folding_nnr(const Test_Deepks::K_Vectors &kv)
212212

213213
if(nu<0)continue;
214214
int iic;
215-
if(GlobalV::KS_SOLVER=="genelpa" || GlobalV::KS_SOLVER=="scalapack_gvx") // save the matrix as column major format
215+
if (ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER() )
216216
{
217217
iic=mu+nu*ParaO.nrow;
218218
}
@@ -232,4 +232,4 @@ void test_deepks::folding_nnr(const Test_Deepks::K_Vectors &kv)
232232
} // end T1
233233
assert(index==this->nnr);
234234
}
235-
}
235+
}

source/module_orbital/ORB_control.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ void ORB_control::setup_2d_division(std::ofstream& ofs_running,
167167

168168
// (1) calculate nrow, ncol, nloc.
169169
if (ks_solver == "genelpa" || ks_solver == "hpseps" || ks_solver == "scalpack"
170-
|| ks_solver == "selinv" || ks_solver == "scalapack_gvx")
170+
|| ks_solver == "selinv" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver")
171171
{
172172
ofs_running << " divide the H&S matrix using 2D block algorithms." << std::endl;
173173
#ifdef __MPI

source/module_orbital/parallel_orbitals.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ void ORB_control::set_trace(std::ofstream& ofs_running)
9898
}
9999
#ifdef __MPI
100100
else if(ks_solver=="scalpack" || ks_solver=="genelpa" || ks_solver=="hpseps"
101-
|| ks_solver=="selinv" || ks_solver=="scalapack_gvx") //xiaohui add 2013-09-02
101+
|| ks_solver=="selinv" || ks_solver=="scalapack_gvx" || ks_solver=="cusolver") //xiaohui add 2013-09-02
102102
{
103103
// ofs_running << " nrow=" << nrow << std::endl;
104104
for (int irow=0; irow< pv->nrow; irow++)
@@ -205,6 +205,9 @@ void ORB_control::divide_HS_2d
205205
// get the 2D index of computer.
206206
pv->dim0 = (int)sqrt((double)dsize); //mohan update 2012/01/13
207207
//while (GlobalV::NPROC_IN_POOL%dim0!=0)
208+
209+
if (ks_solver=="cusolver") pv->dim0 = 1; // Xu Shu add 2022-03-25
210+
208211
while (dsize%pv->dim0!=0)
209212
{
210213
pv->dim0 = pv->dim0 - 1;
@@ -227,6 +230,8 @@ void ORB_control::divide_HS_2d
227230
{
228231
pv->nb = nb2d; // mohan add 2010-06-28
229232
}
233+
234+
if (ks_solver=="cusolver") pv->nb = 1; // Xu Shu add 2022-03-25
230235
ModuleBase::GlobalFunc::OUT(ofs_running,"nb2d", pv->nb);
231236

232237
this->set_parameters(ofs_running, ofs_warning);
@@ -252,7 +257,7 @@ void ORB_control::divide_HS_2d
252257
pv->nloc = pv->MatrixInfo.col_num * pv->MatrixInfo.row_num;
253258

254259
// init blacs context for genelpa, scalapack_gvx and cusolver
255-
if(ks_solver=="genelpa" || ks_solver=="scalapack_gvx")
260+
if (ks_solver == "genelpa" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver")
256261
{
257262
pv->blacs_ctxt = cart2blacs(pv->comm_2D, pv->dim0, pv->dim1,
258263
nlocal, nbands, pv->nb, pv->nrow, pv->desc, pv->desc_wfc);

0 commit comments

Comments
 (0)