Skip to content

Commit 265dac7

Browse files
authored
The United Connector of LAPACK (#6579)
* Add hegvx routine * Update hegvd_op.h/cpp to use the uniform lapackConnector * Update docs for LAPACK connector API files * Update docs for LAPACK connector API files
1 parent 757bb58 commit 265dac7

File tree

5 files changed

+224
-87
lines changed

5 files changed

+224
-87
lines changed

source/source_base/module_container/base/third_party/lapack.h

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
/**
2+
* @file lapack.h
3+
* @brief This is a direct wrapper of some LAPACK routines.
4+
* \b Column-Major version.
5+
* Direct wrapping of standard LAPACK routines. (Column-Major, fortran style)
6+
*
7+
* @warning For Row-major version, please refer to \c source/source_base/module_external/lapack_connector.h.
8+
*
9+
* @note
10+
* Some slight modification are made to fit the C++ style for overloading purpose.
11+
* You can find some function with different parameter list than the original LAPACK routine.
12+
* And some of these parameters are not referred in the function body. They are included just to
13+
* ensure the same parameter list for overloaded functions with a uniform name.
14+
*/
15+
116
#ifndef BASE_THIRD_PARTY_LAPACK_H_
217
#define BASE_THIRD_PARTY_LAPACK_H_
318

@@ -10,6 +25,10 @@
1025
#include <base/third_party/hipsolver.h>
1126
#endif
1227

28+
/// This is a wrapper of some LAPACK routines.
29+
/// Direct wrapping of standard LAPACK routines. (column major, fortran style)
30+
/// with some slight modification to fit the C++ style for overloading purpose.
31+
1332
//Naming convention of lapack subroutines : ammxxx, where
1433
//"a" specifies the data type:
1534
// - d stands for double
@@ -46,6 +65,27 @@ void chegvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
4665
std::complex<float>* work, int* lwork, float* rwork, int* lrwork,
4766
int* iwork, int* liwork, int* info);
4867

68+
void ssygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
69+
const int* n, float* A, const int* lda, float* B, const int* ldb,
70+
const float* vl, const float* vu, const int* il, const int* iu,
71+
const float* abstol, const int* m, float* w, float* Z, const int* ldz,
72+
float* work, const int* lwork, int* iwork, int* ifail, int* info);
73+
void dsygvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
74+
const int* n, double* A, const int* lda, double* B, const int* ldb,
75+
const double* vl, const double* vu, const int* il, const int* iu,
76+
const double* abstol, const int* m, double* w, double* Z, const int* ldz,
77+
double* work, const int* lwork, int* iwork, int* ifail, int* info);
78+
void chegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
79+
const int* n, std::complex<float>* A, const int* lda, std::complex<float>* B, const int* ldb,
80+
const float* vl, const float* vu, const int* il, const int* iu,
81+
const float* abstol, const int* m, float* w, std::complex<float>* Z, const int* ldz,
82+
std::complex<float>* work, const int* lwork, float* rwork, int* iwork, int* ifail, int* info);
83+
void zhegvx_(const int* itype, const char* jobz, const char* range, const char* uplo,
84+
const int* n, std::complex<double>* A, const int* lda, std::complex<double>* B, const int* ldb,
85+
const double* vl, const double* vu, const int* il, const int* iu,
86+
const double* abstol, const int* m, double* w, std::complex<double>* Z, const int* ldz,
87+
std::complex<double>* work, const int* lwork, double* rwork, int* iwork, int* ifail, int* info);
88+
4989
void zhegvd_(const int* itype, const char* jobz, const char* uplo, const int* n,
5090
std::complex<double>* a, const int* lda,
5191
const std::complex<double>* b, const int* ldb, double* w,
@@ -190,6 +230,68 @@ void hegvd(const int itype, const char jobz, const char uplo, const int n,
190230
iwork, &liwork, &info);
191231
}
192232

233+
// Note
234+
// rwork is only needed for complex version
235+
// and we include rwork in the function parameter list
236+
// for simplicity of function overloading
237+
// and unification of function parameter list
238+
static inline
239+
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
240+
float* a, const int lda, float* b, const int ldb,
241+
const float vl, const float vu, const int il, const int iu, const float abstol,
242+
const int m, float* w, float* z, const int ldz,
243+
float* work, const int lwork, float* rwork, int* iwork, int* ifail, int& info)
244+
{
245+
ssygvx_(&itype, &jobz, &range, &uplo, &n,
246+
a, &lda, b, &ldb,
247+
&vl, &vu, &il, &iu,
248+
&abstol, &m, w, z, &ldz,
249+
work, &lwork, iwork, ifail, &info);
250+
}
251+
252+
static inline
253+
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
254+
double* a, const int lda, double* b, const int ldb,
255+
const double vl, const double vu, const int il, const int iu, const double abstol,
256+
const int m, double* w, double* z, const int ldz,
257+
double* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info)
258+
{
259+
dsygvx_(&itype, &jobz, &range, &uplo, &n,
260+
a, &lda, b, &ldb,
261+
&vl, &vu, &il, &iu,
262+
&abstol, &m, w, z, &ldz,
263+
work, &lwork, iwork, ifail, &info);
264+
}
265+
266+
static inline
267+
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
268+
std::complex<float>* a, const int lda, std::complex<float>* b, const int ldb,
269+
const float vl, const float vu, const int il, const int iu, const float abstol,
270+
const int m, float* w, std::complex<float>* z, const int ldz,
271+
std::complex<float>* work, const int lwork, float* rwork, int* iwork, int* ifail, int& info)
272+
{
273+
chegvx_(&itype, &jobz, &range, &uplo, &n,
274+
a, &lda, b, &ldb,
275+
&vl, &vu, &il, &iu,
276+
&abstol, &m, w, z, &ldz,
277+
work, &lwork, rwork, iwork, ifail, &info);
278+
}
279+
280+
static inline
281+
void hegvx(const int itype, const char jobz, const char range, const char uplo, const int n,
282+
std::complex<double>* a, const int lda, std::complex<double>* b, const int ldb,
283+
const double vl, const double vu, const int il, const int iu, const double abstol,
284+
const int m, double* w, std::complex<double>* z, const int ldz,
285+
std::complex<double>* work, const int lwork, double* rwork, int* iwork, int* ifail, int& info)
286+
{
287+
zhegvx_(&itype, &jobz, &range, &uplo, &n,
288+
a, &lda, b, &ldb,
289+
&vl, &vu, &il, &iu,
290+
&abstol, &m, w, z, &ldz,
291+
work, &lwork, rwork, iwork, ifail, &info);
292+
}
293+
294+
193295
// wrap function of fortran lapack routine zheevx.
194296
static inline
195297
void heevx( const int itype, const char jobz, const char range, const char uplo, const int n,

source/source_base/module_external/lapack_connector.h

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
1-
#ifndef LAPACKCONNECTOR_HPP
2-
#define LAPACKCONNECTOR_HPP
1+
/**
2+
* @file lapack_connector.h
3+
*
4+
* @brief This is a wrapper of some LAPACK routines.
5+
* \b Row-Major version.
6+
*
7+
* @warning MAY BE DEPRECATED IN THE FUTURE.
8+
* @warning For Column-major version, please refer to \c source/source_base/module_container/base/third_party/lapack.h.
9+
*
10+
* @note
11+
* !!! Note that
12+
* This wrapper is a <b>C++ style</b> wrapper of LAPACK routines,
13+
* i.e., assuming that the input matrices are in \b row-major order.
14+
* The data layout in C++ is row-major, C style,
15+
* while the original LAPACK is column-major, fortran style.
16+
* (ModuleBase::ComplexMatrix is in row-major order)
17+
* The wrapper will do the data transformation between
18+
* row-major and column-major order automatically.
19+
*
20+
*/
21+
22+
#ifndef LAPACK_CONNECTOR_HPP
23+
#define LAPACK_CONNECTOR_HPP
324

425
#include <new>
526
#include <stdexcept>
@@ -11,8 +32,10 @@
1132

1233
//Naming convention of lapack subroutines : ammxxx, where
1334
//"a" specifies the data type:
14-
// - d stands for double
15-
// - z stands for complex double
35+
// - s stands for float
36+
// - d stands for double
37+
// - c stands for complex float
38+
// - z stands for complex double
1639
//"mm" specifies the type of matrix, for example:
1740
// - he stands for hermitian
1841
// - sy stands for symmetric
@@ -468,4 +491,4 @@ class LapackConnector
468491
cherk_(&uplo_changed, &trans_changed, &n, &k, &alpha, A, &lda, &beta, C, &ldc);
469492
}
470493
};
471-
#endif // LAPACKCONNECTOR_HPP
494+
#endif // LAPACK_CONNECTOR_HPP

source/source_base/module_external/lapack_wrapper.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
#ifndef LAPACK_HPP
22
#define LAPACK_HPP
3+
4+
/// This is a wrapper of some LAPACK routines.
5+
/// Direct wrapping of standard LAPACK routines. (column major, fortran style)
6+
/// including:
7+
/// 1. hegvd: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem
8+
/// 2. heevx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem
9+
/// 3. hegvx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem
10+
/// 4. hegv: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem
11+
12+
313
#include <iostream>
414
extern "C"
515
{

source/source_hsolver/kernels/hegvd_op.cpp

Lines changed: 68 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#include "source_hsolver/kernels/hegvd_op.h"
2+
#include "source_base/module_container/base/third_party/lapack.h"
23

34
#include <algorithm>
45
#include <fstream>
56
#include <iostream>
67

8+
namespace lapackConnector = container::lapackConnector; // see "source_base/module_container/base/third_party/lapack.h"
9+
710
namespace hsolver
811
{
912
// hegvd and sygvd; dn for dense?
@@ -39,7 +42,7 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
3942
//===========================
4043
// calculate all eigenvalues
4144
//===========================
42-
LapackWrapper::xhegvd(1,
45+
lapackConnector::hegvd(1,
4346
'V',
4447
'U',
4548
nstart,
@@ -58,7 +61,7 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
5861

5962
if (info != 0)
6063
{
61-
std::cout << "Error: xhegvd failed, linear dependent basis functions\n"
64+
std::cout << "Error: hegvd failed, linear dependent basis functions\n"
6265
<< ", wrong initialization of wavefunction, or wavefunction information loss\n"
6366
<< ", output overlap matrix scc.txt to check\n"
6467
<< std::endl;
@@ -82,62 +85,62 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
8285
}
8386
};
8487

85-
template <typename T>
86-
struct hegv_op<T, base_device::DEVICE_CPU>
87-
{
88-
using Real = typename GetTypeReal<T>::type;
89-
void operator()(const base_device::DEVICE_CPU* d,
90-
const int nbase,
91-
const int ldh,
92-
const T* hcc,
93-
T* scc,
94-
Real* eigenvalue,
95-
T* vcc)
96-
{
97-
for (int i = 0; i < nbase * ldh; i++)
98-
{
99-
vcc[i] = hcc[i];
100-
}
101-
102-
int info = 0;
103-
104-
int lwork = 2 * nbase - 1;
105-
T* work = new T[lwork];
106-
Parallel_Reduce::ZEROS(work, lwork);
107-
108-
int lrwork = 3 * nbase - 2;
109-
Real* rwork = new Real[lrwork];
110-
Parallel_Reduce::ZEROS(rwork, lrwork);
111-
112-
//===========================
113-
// calculate all eigenvalues
114-
//===========================
115-
LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info);
116-
117-
if (info != 0)
118-
{
119-
std::cout << "Error: xhegv failed, linear dependent basis functions\n"
120-
<< ", wrong initialization of wavefunction, or wavefunction information loss\n"
121-
<< ", output overlap matrix scc.txt to check\n"
122-
<< std::endl;
123-
// print scc to file scc.txt
124-
std::ofstream ofs("scc.txt");
125-
for (int i = 0; i < nbase; i++)
126-
{
127-
for (int j = 0; j < nbase; j++)
128-
{
129-
ofs << scc[i * ldh + j] << " ";
130-
}
131-
ofs << std::endl;
132-
}
133-
ofs.close();
134-
}
135-
assert(0 == info);
136-
137-
delete[] work;
138-
delete[] rwork;
139-
}
140-
};
88+
// template <typename T>
89+
// struct hegv_op<T, base_device::DEVICE_CPU>
90+
// {
91+
// using Real = typename GetTypeReal<T>::type;
92+
// void operator()(const base_device::DEVICE_CPU* d,
93+
// const int nbase,
94+
// const int ldh,
95+
// const T* hcc,
96+
// T* scc,
97+
// Real* eigenvalue,
98+
// T* vcc)
99+
// {
100+
// for (int i = 0; i < nbase * ldh; i++)
101+
// {
102+
// vcc[i] = hcc[i];
103+
// }
104+
105+
// int info = 0;
106+
107+
// int lwork = 2 * nbase - 1;
108+
// T* work = new T[lwork];
109+
// Parallel_Reduce::ZEROS(work, lwork);
110+
111+
// int lrwork = 3 * nbase - 2;
112+
// Real* rwork = new Real[lrwork];
113+
// Parallel_Reduce::ZEROS(rwork, lrwork);
114+
115+
// //===========================
116+
// // calculate all eigenvalues
117+
// //===========================
118+
// LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info);
119+
120+
// if (info != 0)
121+
// {
122+
// std::cout << "Error: xhegv failed, linear dependent basis functions\n"
123+
// << ", wrong initialization of wavefunction, or wavefunction information loss\n"
124+
// << ", output overlap matrix scc.txt to check\n"
125+
// << std::endl;
126+
// // print scc to file scc.txt
127+
// std::ofstream ofs("scc.txt");
128+
// for (int i = 0; i < nbase; i++)
129+
// {
130+
// for (int j = 0; j < nbase; j++)
131+
// {
132+
// ofs << scc[i * ldh + j] << " ";
133+
// }
134+
// ofs << std::endl;
135+
// }
136+
// ofs.close();
137+
// }
138+
// assert(0 == info);
139+
140+
// delete[] work;
141+
// delete[] rwork;
142+
// }
143+
// };
141144

142145
// heevx and syevx
143146
/**
@@ -174,7 +177,7 @@ struct heevx_op<T, base_device::DEVICE_CPU>
174177

175178
// When lwork = -1, the demension of work will be assumed
176179
// Assume the denmension of work by output work[0]
177-
LapackWrapper::xheevx(
180+
lapackConnector::heevx(
178181
1, // ITYPE = 1: A*x = (lambda)*B*x
179182
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
180183
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
@@ -208,7 +211,7 @@ struct heevx_op<T, base_device::DEVICE_CPU>
208211
// V is the output of the function, the storage space is also (nstart * ldh), and the data size of valid V
209212
// obtained by the zhegvx operation is (nstart * nstart) and stored in zux (internal to the function). When
210213
// the function is output, the data of zux will be mapped to the corresponding position of V.
211-
LapackWrapper::xheevx(
214+
lapackConnector::heevx(
212215
1, // ITYPE = 1: A*x = (lambda)*B*x
213216
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
214217
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
@@ -267,7 +270,7 @@ struct hegvx_op<T, base_device::DEVICE_CPU>
267270
int* iwork = new int[5 * nbase];
268271
int* ifail = new int[nbase];
269272

270-
LapackWrapper::xhegvx(
273+
lapackConnector::hegvx(
271274
1, // ITYPE = 1: A*x = (lambda)*B*x
272275
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
273276
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
@@ -297,7 +300,7 @@ struct hegvx_op<T, base_device::DEVICE_CPU>
297300
delete[] work;
298301
work = new T[lwork];
299302

300-
LapackWrapper::xhegvx(1,
303+
lapackConnector::hegvx(1,
301304
'V',
302305
'I',
303306
'U',
@@ -338,12 +341,12 @@ template struct heevx_op<std::complex<double>, base_device::DEVICE_CPU>;
338341
template struct hegvx_op<std::complex<float>, base_device::DEVICE_CPU>;
339342
template struct hegvx_op<std::complex<double>, base_device::DEVICE_CPU>;
340343

341-
template struct hegv_op<std::complex<float>, base_device::DEVICE_CPU>;
342-
template struct hegv_op<std::complex<double>, base_device::DEVICE_CPU>;
344+
// template struct hegv_op<std::complex<float>, base_device::DEVICE_CPU>;
345+
// template struct hegv_op<std::complex<double>, base_device::DEVICE_CPU>;
343346
#ifdef __LCAO
344347
template struct hegvd_op<double, base_device::DEVICE_CPU>;
345348
template struct heevx_op<double, base_device::DEVICE_CPU>;
346349
template struct hegvx_op<double, base_device::DEVICE_CPU>;
347-
template struct hegv_op<double, base_device::DEVICE_CPU>;
350+
// template struct hegv_op<double, base_device::DEVICE_CPU>;
348351
#endif
349352
} // namespace hsolver

0 commit comments

Comments
 (0)