Skip to content

Commit ecb0aad

Browse files
committed
Update hegvd_op.h/cpp to use the uniform lapackConnector
1 parent 24e222f commit ecb0aad

File tree

3 files changed

+95
-82
lines changed

3 files changed

+95
-82
lines changed

source/source_base/module_external/lapack_wrapper.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
#ifndef LAPACK_HPP
22
#define LAPACK_HPP
3+
4+
/// This is a wrapper of some LAPACK routines.
5+
/// Direct wrapping of standard LAPACK routines. (column major, fortran style)
6+
/// including:
7+
/// 1. hegvd: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem
8+
/// 2. heevx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem
9+
/// 3. hegvx: compute the first m eigenvalues and their corresponding eigenvectors of a generalized Hermitian-definite eigenproblem
10+
/// 4. hegv: compute all the eigenvalues and eigenvectors of a generalized Hermitian-definite eigenproblem
11+
12+
313
#include <iostream>
414
extern "C"
515
{

source/source_hsolver/kernels/hegvd_op.cpp

Lines changed: 69 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
#include "source_hsolver/kernels/hegvd_op.h"
2+
// #include "source_base/module_external/lapack_wrapper.h"
3+
#include "source_base/module_container/base/third_party/lapack.h"
24

35
#include <algorithm>
46
#include <fstream>
57
#include <iostream>
68

9+
namespace lapackConnector = container::lapackConnector; // see "source_base/module_container/base/third_party/lapack.h"
10+
711
namespace hsolver
812
{
913
// hegvd and sygvd; dn for dense?
@@ -39,7 +43,7 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
3943
//===========================
4044
// calculate all eigenvalues
4145
//===========================
42-
LapackWrapper::xhegvd(1,
46+
lapackConnector::hegvd(1,
4347
'V',
4448
'U',
4549
nstart,
@@ -58,7 +62,7 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
5862

5963
if (info != 0)
6064
{
61-
std::cout << "Error: xhegvd failed, linear dependent basis functions\n"
65+
std::cout << "Error: hegvd failed, linear dependent basis functions\n"
6266
<< ", wrong initialization of wavefunction, or wavefunction information loss\n"
6367
<< ", output overlap matrix scc.txt to check\n"
6468
<< std::endl;
@@ -82,62 +86,62 @@ struct hegvd_op<T, base_device::DEVICE_CPU>
8286
}
8387
};
8488

85-
template <typename T>
86-
struct hegv_op<T, base_device::DEVICE_CPU>
87-
{
88-
using Real = typename GetTypeReal<T>::type;
89-
void operator()(const base_device::DEVICE_CPU* d,
90-
const int nbase,
91-
const int ldh,
92-
const T* hcc,
93-
T* scc,
94-
Real* eigenvalue,
95-
T* vcc)
96-
{
97-
for (int i = 0; i < nbase * ldh; i++)
98-
{
99-
vcc[i] = hcc[i];
100-
}
101-
102-
int info = 0;
103-
104-
int lwork = 2 * nbase - 1;
105-
T* work = new T[lwork];
106-
Parallel_Reduce::ZEROS(work, lwork);
107-
108-
int lrwork = 3 * nbase - 2;
109-
Real* rwork = new Real[lrwork];
110-
Parallel_Reduce::ZEROS(rwork, lrwork);
111-
112-
//===========================
113-
// calculate all eigenvalues
114-
//===========================
115-
LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info);
116-
117-
if (info != 0)
118-
{
119-
std::cout << "Error: xhegv failed, linear dependent basis functions\n"
120-
<< ", wrong initialization of wavefunction, or wavefunction information loss\n"
121-
<< ", output overlap matrix scc.txt to check\n"
122-
<< std::endl;
123-
// print scc to file scc.txt
124-
std::ofstream ofs("scc.txt");
125-
for (int i = 0; i < nbase; i++)
126-
{
127-
for (int j = 0; j < nbase; j++)
128-
{
129-
ofs << scc[i * ldh + j] << " ";
130-
}
131-
ofs << std::endl;
132-
}
133-
ofs.close();
134-
}
135-
assert(0 == info);
136-
137-
delete[] work;
138-
delete[] rwork;
139-
}
140-
};
89+
// template <typename T>
90+
// struct hegv_op<T, base_device::DEVICE_CPU>
91+
// {
92+
// using Real = typename GetTypeReal<T>::type;
93+
// void operator()(const base_device::DEVICE_CPU* d,
94+
// const int nbase,
95+
// const int ldh,
96+
// const T* hcc,
97+
// T* scc,
98+
// Real* eigenvalue,
99+
// T* vcc)
100+
// {
101+
// for (int i = 0; i < nbase * ldh; i++)
102+
// {
103+
// vcc[i] = hcc[i];
104+
// }
105+
106+
// int info = 0;
107+
108+
// int lwork = 2 * nbase - 1;
109+
// T* work = new T[lwork];
110+
// Parallel_Reduce::ZEROS(work, lwork);
111+
112+
// int lrwork = 3 * nbase - 2;
113+
// Real* rwork = new Real[lrwork];
114+
// Parallel_Reduce::ZEROS(rwork, lrwork);
115+
116+
// //===========================
117+
// // calculate all eigenvalues
118+
// //===========================
119+
// LapackWrapper::xhegv(1, 'V', 'U', nbase, vcc, ldh, scc, ldh, eigenvalue, work, lwork, rwork, info);
120+
121+
// if (info != 0)
122+
// {
123+
// std::cout << "Error: xhegv failed, linear dependent basis functions\n"
124+
// << ", wrong initialization of wavefunction, or wavefunction information loss\n"
125+
// << ", output overlap matrix scc.txt to check\n"
126+
// << std::endl;
127+
// // print scc to file scc.txt
128+
// std::ofstream ofs("scc.txt");
129+
// for (int i = 0; i < nbase; i++)
130+
// {
131+
// for (int j = 0; j < nbase; j++)
132+
// {
133+
// ofs << scc[i * ldh + j] << " ";
134+
// }
135+
// ofs << std::endl;
136+
// }
137+
// ofs.close();
138+
// }
139+
// assert(0 == info);
140+
141+
// delete[] work;
142+
// delete[] rwork;
143+
// }
144+
// };
141145

142146
// heevx and syevx
143147
/**
@@ -174,7 +178,7 @@ struct heevx_op<T, base_device::DEVICE_CPU>
174178

175179
// When lwork = -1, the demension of work will be assumed
176180
// Assume the denmension of work by output work[0]
177-
LapackWrapper::xheevx(
181+
lapackConnector::heevx(
178182
1, // ITYPE = 1: A*x = (lambda)*B*x
179183
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
180184
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
@@ -208,7 +212,7 @@ struct heevx_op<T, base_device::DEVICE_CPU>
208212
// V is the output of the function, the storage space is also (nstart * ldh), and the data size of valid V
209213
// obtained by the zhegvx operation is (nstart * nstart) and stored in zux (internal to the function). When
210214
// the function is output, the data of zux will be mapped to the corresponding position of V.
211-
LapackWrapper::xheevx(
215+
lapackConnector::heevx(
212216
1, // ITYPE = 1: A*x = (lambda)*B*x
213217
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
214218
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
@@ -267,7 +271,7 @@ struct hegvx_op<T, base_device::DEVICE_CPU>
267271
int* iwork = new int[5 * nbase];
268272
int* ifail = new int[nbase];
269273

270-
LapackWrapper::xhegvx(
274+
lapackConnector::hegvx(
271275
1, // ITYPE = 1: A*x = (lambda)*B*x
272276
'V', // JOBZ = 'V': Compute eigenvalues and eigenvectors.
273277
'I', // RANGE = 'I': the IL-th through IU-th eigenvalues will be found.
@@ -297,7 +301,7 @@ struct hegvx_op<T, base_device::DEVICE_CPU>
297301
delete[] work;
298302
work = new T[lwork];
299303

300-
LapackWrapper::xhegvx(1,
304+
lapackConnector::hegvx(1,
301305
'V',
302306
'I',
303307
'U',
@@ -338,12 +342,12 @@ template struct heevx_op<std::complex<double>, base_device::DEVICE_CPU>;
338342
template struct hegvx_op<std::complex<float>, base_device::DEVICE_CPU>;
339343
template struct hegvx_op<std::complex<double>, base_device::DEVICE_CPU>;
340344

341-
template struct hegv_op<std::complex<float>, base_device::DEVICE_CPU>;
342-
template struct hegv_op<std::complex<double>, base_device::DEVICE_CPU>;
345+
// template struct hegv_op<std::complex<float>, base_device::DEVICE_CPU>;
346+
// template struct hegv_op<std::complex<double>, base_device::DEVICE_CPU>;
343347
#ifdef __LCAO
344348
template struct hegvd_op<double, base_device::DEVICE_CPU>;
345349
template struct heevx_op<double, base_device::DEVICE_CPU>;
346350
template struct hegvx_op<double, base_device::DEVICE_CPU>;
347-
template struct hegv_op<double, base_device::DEVICE_CPU>;
351+
// template struct hegv_op<double, base_device::DEVICE_CPU>;
348352
#endif
349353
} // namespace hsolver

source/source_hsolver/kernels/hegvd_op.h

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
// And will be moved to a global module(module base) later.
2727

2828
#include "source_base/macros.h"
29-
#include "source_base/module_external/lapack_wrapper.h"
3029
#include "source_base/parallel_reduce.h"
3130
#include "source_base/module_device/types.h"
3231

@@ -68,22 +67,22 @@ struct hegvd_op
6867
void operator()(const Device* d, const int nstart, const int ldh, const T* A, const T* B, Real* W, T* V);
6968
};
7069

71-
template <typename T, typename Device>
72-
struct hegv_op
73-
{
74-
using Real = typename GetTypeReal<T>::type;
75-
/// @brief HEGV computes first m eigenvalues and eigenvectors of a complex generalized
76-
/// Input Parameters
77-
/// @param d : the type of device
78-
/// @param nbase : the number of dim of the matrix
79-
/// @param ldh : the number of dmx of the matrix
80-
/// @param A : the hermitian matrix A in A x=lambda B x (col major)
81-
/// @param B : the overlap matrix B in A x=lambda B x (col major)
82-
/// Output Parameter
83-
/// @param W : calculated eigenvalues
84-
/// @param V : calculated eigenvectors (col major)
85-
void operator()(const Device* d, const int nstart, const int ldh, const T* A, T* B, Real* W, T* V);
86-
};
70+
// template <typename T, typename Device>
71+
// struct hegv_op
72+
// {
73+
// using Real = typename GetTypeReal<T>::type;
74+
// /// @brief HEGV computes first m eigenvalues and eigenvectors of a complex generalized
75+
// /// Input Parameters
76+
// /// @param d : the type of device
77+
// /// @param nbase : the number of dim of the matrix
78+
// /// @param ldh : the number of dmx of the matrix
79+
// /// @param A : the hermitian matrix A in A x=lambda B x (col major)
80+
// /// @param B : the overlap matrix B in A x=lambda B x (col major)
81+
// /// Output Parameter
82+
// /// @param W : calculated eigenvalues
83+
// /// @param V : calculated eigenvectors (col major)
84+
// void operator()(const Device* d, const int nstart, const int ldh, const T* A, T* B, Real* W, T* V);
85+
// };
8786

8887
template <typename T, typename Device>
8988
struct hegvx_op

0 commit comments

Comments
 (0)