Skip to content

Commit b0c36c3

Browse files
haozhihandyzheng
authored andcommitted
Replace the APIs for more efficient solving eigenpairs in davidson and CG-subspace (#1703)
* Add dngvd operation for LAPACK_subspace * Change gvx_op to evx_op in davidson method * Add dnevx_op & dngvd_op CUDA version
1 parent 4dc50f9 commit b0c36c3

File tree

6 files changed

+522
-44
lines changed

6 files changed

+522
-44
lines changed

source/module_base/lapack_connector.h

Lines changed: 96 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,42 @@
2222
extern "C"
2323
{
2424
int ilaenv_(int* ispec,const char* name,const char* opts,
25-
const int* n1,const int* n2,const int* n3,const int* n4);
25+
const int* n1,const int* n2,const int* n3,const int* n4);
26+
27+
2628
// solve the generalized eigenproblem Ax=eBx, where A is Hermitian and complex couble
27-
// zhegv_ returns all eigenvalues while zhegvx_ returns selected ones
29+
// zhegv_ & zhegvd_ returns all eigenvalues while zhegvx_ returns selected ones
30+
31+
void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
32+
std::complex<double>* a, const int* lda, std::complex<double>* b, const int* ldb,
33+
double* w, std::complex<double>* work, int* lwork, double* rwork, int* lrwork,
34+
int* iwork, int* liwork, int* info);
35+
2836
void zhegv_(const int* itype,const char* jobz,const char* uplo,const int* n,
2937
std::complex<double>* a,const int* lda,std::complex<double>* b,const int* ldb,
3038
double* w,std::complex<double>* work,int* lwork,double* rwork,int* info);
39+
3140
void zhegvx_(const int* itype,const char* jobz,const char* range,const char* uplo,
3241
const int* n,std::complex<double> *a,const int* lda,std::complex<double> *b,
3342
const int* ldb,const double* vl,const double* vu,const int* il,
3443
const int* iu,const double* abstol,const int* m,double* w,
3544
std::complex<double> *z,const int *ldz,std::complex<double> *work,const int* lwork,
3645
double* rwork,int* iwork,int* ifail,int* info);
46+
47+
// solve the eigenproblem Ax=ex, where A is Hermitian and complex couble
48+
// zheev_ returns all eigenvalues while zheevx_ returns selected ones
49+
void zheev_(const char* jobz,const char* uplo,const int* n,std::complex<double> *a,
50+
const int* lda,double* w,std::complex<double >* work,const int* lwork,
51+
double* rwork,int* info);
52+
53+
void zheevx_(const char* jobz,const char* range,const char* uplo,
54+
const int* n,std::complex<double> *a,const int* lda,
55+
const double* vl,const double* vu,const int* il,
56+
const int* iu,const double* abstol,const int* m,double* w,
57+
std::complex<double> *z,const int *ldz,std::complex<double> *work,const int* lwork,
58+
double* rwork,int* iwork,int* ifail,int* info);
59+
60+
3761
// solve the generalized eigenproblem Ax=eBx, where A is Symmetric and real couble
3862
// dsygv_ returns all eigenvalues while dsygvx_ returns selected ones
3963
void dsygv_(const int* itype, const char* jobz,const char* uplo, const int* n,
@@ -47,10 +71,8 @@ extern "C"
4771
// solve the eigenproblem Ax=ex, where A is Symmetric and real double
4872
void dsyev_(const char* jobz,const char* uplo,const int* n,double *a,
4973
const int* lda,double* w,double* work,const int* lwork, int* info);
50-
// solve the eigenproblem Ax=ex, where A is Hermitian and complex couble
51-
void zheev_(const char* jobz,const char* uplo,const int* n,std::complex<double> *a,
52-
const int* lda,double* w,std::complex<double >* work,const int* lwork,
53-
double* rwork,int* info);
74+
75+
5476
// dsytrf_ computes the Bunch-Kaufman factorization of a double precision
5577
// symmetric matrix, while dsytri takes its output to perform martrix inversion
5678
void dsytrf_(const char* uplo, const int* n, double * a, const int* lda,
@@ -226,7 +248,34 @@ class LapackConnector
226248
return nb;
227249
}
228250

229-
// wrap function of fortran lapack routine zhegv.
251+
// wrap function of fortran lapack routine zhegvd.
252+
static inline
253+
void zhegvd(const int itype, const char jobz, const char uplo, const int n,
254+
ModuleBase::ComplexMatrix& a, const int lda,
255+
ModuleBase::ComplexMatrix& b, const int ldb, double* w,
256+
std::complex<double>* work, int lwork, double* rwork, int lrwork,
257+
int* iwork, int liwork, int info)
258+
{
259+
// Transpose the std::complex matrix to the fortran-form real-std::complex array.
260+
std::complex<double>* aux = LapackConnector::transpose(a, n, lda);
261+
std::complex<double>* bux = LapackConnector::transpose(b, n, ldb);
262+
263+
// call the fortran routine
264+
zhegvd_(&itype, &jobz, &uplo, &n,
265+
aux, &lda, bux, &ldb, w,
266+
work, &lwork, rwork, &lrwork,
267+
iwork, &liwork, &info);
268+
269+
// Transpose the fortran-form real-std::complex array to the std::complex matrix.
270+
LapackConnector::transpose(aux, a, n, lda);
271+
LapackConnector::transpose(bux, b, n, ldb);
272+
273+
// free the memory.
274+
delete[] aux;
275+
delete[] bux;
276+
}
277+
278+
// wrap function of fortran lapack routine zhegv ( ModuleBase::ComplexMatrix version ).
230279
static inline
231280
void zhegv( const int itype,const char jobz,const char uplo,const int n,ModuleBase::ComplexMatrix& a,
232281
const int lda,ModuleBase::ComplexMatrix& b,const int ldb,double* w,std::complex<double>* work,
@@ -244,20 +293,21 @@ class LapackConnector
244293
delete[] aux;
245294
delete[] bux;
246295
}
247-
// wrap function of fortran lapack routine zhegv.
296+
297+
// wrap function of fortran lapack routine zhegv ( pointer version ) .
248298
static inline
249299
void zhegv( const int itype, const char jobz, const char uplo, const int n, std::complex<double>* a,
250300
const int lda, const std::complex<double>* b, const int ldb, double* w, std::complex<double>* work,
251301
int lwork, double* rwork, int info, int ld_real)
252-
{ // Transpose the std::complex matrix to the fortran-form real-std::complex array.
302+
{
303+
// Transpose the std::complex matrix to the fortran-form real-std::complex array.
253304
std::complex<double>* aux = LapackConnector::transpose(a, n, lda, ld_real);
254305
std::complex<double>* bux = LapackConnector::transpose(b, n, ldb, ld_real);
255306

256307
// call the fortran routine
257308
zhegv_(&itype, &jobz, &uplo, &n, aux, &lda, bux, &ldb, w, work, &lwork, rwork, &info);
309+
258310
// Transpose the fortran-form real-std::complex array to the std::complex matrix.
259-
// LapackConnector::transpose(aux, a, n, lda);
260-
// LapackConnector::transpose(bux, b, n, ldb);
261311
for (int i = 0; i < n; ++i)
262312
{
263313
for (int j = 0; j < lda; ++j)
@@ -270,7 +320,7 @@ class LapackConnector
270320
delete[] bux;
271321
}
272322

273-
// wrap function of fortran lapack routine zheev.
323+
// wrap function of fortran lapack routine zhegvx ( ModuleBase::ComplexMatrix version ).
274324
static inline
275325
void zhegvx( const int itype, const char jobz, const char range, const char uplo,
276326
const int n, const ModuleBase::ComplexMatrix& a, const int lda, const ModuleBase::ComplexMatrix& b,
@@ -298,7 +348,8 @@ class LapackConnector
298348
delete[] zux;
299349

300350
}
301-
// wrap function of fortran lapack routine zheev.
351+
352+
// wrap function of fortran lapack routine zhegvx ( pointer version ).
302353
static inline
303354
void zhegvx( const int itype, const char jobz, const char range, const char uplo,
304355
const int n, const std::complex<double>* a, const int lda, const std::complex<double>* b,
@@ -331,6 +382,38 @@ class LapackConnector
331382
delete[] zux;
332383
}
333384

385+
static inline
386+
void zheevx( const int itype, const char jobz, const char range, const char uplo, const int n,
387+
const std::complex<double>* a, const int lda, const double vl, const double vu, const int il, const int iu,
388+
const double abstol, const int m, double* w, std::complex<double>* z, const int ldz,
389+
std::complex<double>* work, const int lwork, double* rwork, int* iwork, int* ifail, int info, int nbase_x)
390+
{
391+
// Transpose the std::complex matrix to the fortran-form real-std::complex array.
392+
std::complex<double>* aux = LapackConnector::transpose(a, n, lda, nbase_x);
393+
std::complex<double>* zux = new std::complex<double>[n*iu];// mohan modify 2009-08-02
394+
395+
// call the fortran routine
396+
zheevx_(&jobz, &range, &uplo, &n,
397+
aux, &lda, &vl, &vu, &il, &iu,
398+
&abstol, &m, w, zux, &ldz,
399+
work, &lwork, rwork, iwork, ifail, &info);
400+
401+
// Transpose the fortran-form real-std::complex array to the std::complex matrix
402+
for (int i = 0; i < iu; ++i)
403+
{
404+
for (int j = 0; j < n; ++j)
405+
{
406+
z[j * nbase_x + i] = zux[i*n+j];
407+
}
408+
}
409+
410+
// free the memory.
411+
delete[] aux;
412+
delete[] zux;
413+
}
414+
415+
416+
334417
// calculate the eigenvalues and eigenfunctions of a real symmetric matrix.
335418
static inline
336419
void dsygv( const int itype,const char jobz,const char uplo,const int n,ModuleBase::matrix& a,

source/module_hsolver/diago_david.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,16 +625,17 @@ void DiagoDavid<FPTYPE, Device>::diag_zhegvx(const int& n, // nbase
625625
resmem_var_op()(this->ctx, eigenvalue_gpu, this->nbase_x);
626626
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, eigenvalue_gpu, eigenvalue, this->nbase_x);
627627

628-
dngvx_op<FPTYPE, Device>()(this->ctx, n, this->nbase_x, this->hcc, this->scc, m, eigenvalue_gpu, this->vcc);
628+
dnevx_op<FPTYPE, Device>()(this->ctx, n, this->nbase_x, this->hcc, m, eigenvalue_gpu, this->vcc);
629629

630630
syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, eigenvalue, eigenvalue_gpu, this->nbase_x);
631631
delmem_var_op()(this->ctx, eigenvalue_gpu);
632632
#endif
633633
}
634634
else
635635
{
636-
dngvx_op<FPTYPE,
637-
Device>()(this->ctx, n, this->nbase_x, this->hcc, this->scc, m, this->eigenvalue, this->vcc);
636+
// dngvx_op<FPTYPE,
637+
// Device>()(this->ctx, n, this->nbase_x, this->hcc, this->scc, m, this->eigenvalue, this->vcc);
638+
dnevx_op<FPTYPE, Device>()(this->ctx, n, this->nbase_x, this->hcc, m, this->eigenvalue, this->vcc);
638639
}
639640
}
640641

source/module_hsolver/diago_iter_assist.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ void DiagoIterAssist<FPTYPE, Device>::diagH_LAPACK(
432432
//===========================
433433
// calculate all eigenvalues
434434
//===========================
435-
dngv_op<FPTYPE, Device>()(ctx, nstart, ldh, hcc, scc, res, vcc);
435+
// dngv_op<FPTYPE, Device>()(ctx, nstart, ldh, hcc, scc, res, vcc);
436+
dngvd_op<FPTYPE, Device>()(ctx, nstart, ldh, hcc, scc, res, vcc);
436437
}
437438
else {
438439
//=====================================

source/module_hsolver/include/dngvd_op.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ template <typename FPTYPE, typename Device> struct dngvx_op
1616
/// @brief DNGVX computes the first m eigenvalues ​​and their corresponding eigenvectors of
1717
/// a complex generalized Hermitian-definite eigenproblem
1818
///
19+
/// In this op, the CPU version is implemented through the `gvx` interface, and the CUDA version
20+
/// is implemented through the `gvd` interface and acquires the first m eigenpairs.
21+
/// API doc:
22+
/// 1. zhegvx: https://netlib.org/lapack/explore-html/df/d9a/group__complex16_h_eeigen_ga8ea76cbbb14edb5a22069e203fc8e8b2.html
23+
/// 2. cusolverDnZhegvd: https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdn-t-sygvd
24+
///
1925
/// Input Parameters
2026
/// @param d : the type of device
2127
/// @param nstart : the number of cols of the matrix
@@ -41,6 +47,41 @@ template <typename FPTYPE, typename Device> struct dngv_op
4147
/// @brief DNGV computes all the eigenvalues and eigenvectors of a complex generalized
4248
/// Hermitian-definite eigenproblem
4349
///
50+
/// In this op, the CPU version is implemented through the `gv` interface, and the CUDA version
51+
/// is implemented through the `gvd` interface.
52+
/// API doc:
53+
/// 1. zhegv: https://netlib.org/lapack/explore-html/df/d9a/group__complex16_h_eeigen_gaf7b790b3b89de432a423c9006c1cc1ac.html
54+
/// 2. cusolverDnZhegvd: https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdn-t-sygvd
55+
///
56+
/// Input Parameters
57+
/// @param d : the type of device
58+
/// @param nstart : the number of cols of the matrix
59+
/// @param ldh : the number of rows of the matrix
60+
/// @param A : the hermitian matrix A in A x=lambda B x (row major)
61+
/// @param B : the overlap matrix B in A x=lambda B x (row major)
62+
/// Output Parameter
63+
/// @param W : calculated eigenvalues
64+
/// @param V : calculated eigenvectors (row major)
65+
void operator()(const Device* d,
66+
const int nstart,
67+
const int ldh,
68+
const std::complex<FPTYPE>* A,
69+
const std::complex<FPTYPE>* B,
70+
double* W,
71+
std::complex<FPTYPE>* V);
72+
};
73+
74+
template <typename FPTYPE, typename Device> struct dngvd_op
75+
{
76+
/// @brief DNGVD computes all the eigenvalues and eigenvectors of a complex generalized
77+
/// Hermitian-definite eigenproblem. If eigenvectors are desired, it uses a divide and conquer algorithm.
78+
///
79+
/// In this op, the CPU version is implemented through the `gvd` interface, and the CUDA version
80+
/// is implemented through the `gvd` interface.
81+
/// API doc:
82+
/// 1. zhegvd: https://netlib.org/lapack/explore-html/df/d9a/group__complex16_h_eeigen_ga74fdf9b5a16c90d8b7a589dec5ca058a.html
83+
/// 2. cusolverDnZhegvd: https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdn-t-sygvd
84+
///
4485
/// Input Parameters
4586
/// @param d : the type of device
4687
/// @param nstart : the number of cols of the matrix
@@ -59,6 +100,36 @@ template <typename FPTYPE, typename Device> struct dngv_op
59100
std::complex<FPTYPE>* V);
60101
};
61102

103+
104+
template <typename FPTYPE, typename Device> struct dnevx_op
105+
{
106+
/// @brief DNEVX computes the first m eigenvalues ​​and their corresponding eigenvectors of
107+
/// a complex generalized Hermitian-definite eigenproblem
108+
///
109+
/// In this op, the CPU version is implemented through the `evx` interface, and the CUDA version
110+
/// is implemented through the `evd` interface and acquires the first m eigenpairs.
111+
/// API doc:
112+
/// 1. zheevx: https://netlib.org/lapack/explore-html/df/d9a/group__complex16_h_eeigen_gaabef68a9c7b10df7aef8f4fec89fddbe.html
113+
/// 2. cusolverDnZheevd: https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdn-t-syevd
114+
///
115+
/// Input Parameters
116+
/// @param d : the type of device
117+
/// @param nstart : the number of cols of the matrix
118+
/// @param ldh : the number of rows of the matrix
119+
/// @param A : the hermitian matrix A in A x=lambda B x (row major)
120+
/// Output Parameter
121+
/// @param W : calculated eigenvalues
122+
/// @param V : calculated eigenvectors (row major)
123+
void operator()(const Device* d,
124+
const int nstart,
125+
const int ldh,
126+
const std::complex<FPTYPE>* A,
127+
const int m,
128+
double* W,
129+
std::complex<FPTYPE>* V);
130+
};
131+
132+
62133
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
63134

64135
void createCUSOLVERhandle();

0 commit comments

Comments
 (0)