Skip to content

Commit 523085c

Browse files
committed
test generalized eigensolver
1 parent 0b010e0 commit 523085c

File tree

5 files changed

+153
-16
lines changed

5 files changed

+153
-16
lines changed

source/module_base/scalapack_connector.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,20 @@ extern "C"
109109
double* w, std::complex<double>* z, const int* iz, const int* jz, const int* descz,
110110
std::complex<double>* work, const int* lwork, double* rwork, const int* lrwork, int* info);
111111

112-
void pzgetri_(
112+
void pzheevd_(const char* jobz, const char* uplo, const int* n,
113+
std::complex<double>* a, const int* ia, const int* ja, const int* desca,
114+
double* w, std::complex<double>* z, const int* iz, const int* jz, const int* descz,
115+
std::complex<double>* work, const int* lwork, double* rwork, const int* lrwork, int* iwork, const int* liwork, int* info);
116+
117+
void pzheevx_(const char* jobz, const char* range, const char* uplo, const int* n,
118+
std::complex<double>* a, const int* ia, const int* ja, const int* desca,
119+
const double* vl, const double* vu, const int* il, const int* iu, const double* abstol,
120+
int* m, int* nz, double* w, const double* orfac,
121+
std::complex<double>* z, const int* iz, const int* jz, const int* descz,
122+
std::complex<double>* work, const int* lwork, double* rwork, const int* lrwork, int* iwork, const int* liwork,
123+
int* ifail, int* iclustr, double* gap, int* info);
124+
125+
void pzgetri_(
113126
const int *n,
114127
const std::complex<double> *A, const int *ia, const int *ja, const int *desca,
115128
int *ipiv, const std::complex<double> *work, const int *lwork, const int *iwork, const int *liwork, const int *info);

source/module_lr/utils/lr_util.cpp

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ namespace LR_Util
184184
void diag_scalapack(const int& n, double* mat, double* eigval, double* eigvec, const int(&desc)[9])
185185
{
186186
ModuleBase::TITLE("LR_Util", "diag_scalapack<double>");
187-
char jobz = 'V', uplo = 'U';
187+
const char jobz = 'V', uplo = 'U';
188188
const int minus_one = -1;
189189
const int i1 = 1;
190190
int info = 0;
@@ -204,24 +204,106 @@ namespace LR_Util
204204
void diag_scalapack(const int& n, std::complex<double>* mat, double* eigval, std::complex<double>* eigvec, const int(&desc)[9])
205205
{
206206
ModuleBase::TITLE("LR_Util", "diag_lapack<complex<double>>");
207-
char jobz = 'V', uplo = 'U';
207+
const char jobz = 'V', uplo = 'U';
208208
const int minus_one = -1;
209209
const int i1 = 1;
210210
int info = 0;
211-
std::complex<double>lwork_tmp(0., 0.);
212-
double lrwork_tmp = 0.0;
213-
pzheev_(&jobz, &uplo, &n,
211+
std::vector<std::complex<double>> work(1, 0.0);
212+
std::vector<double>rwork(1, 0.0);
213+
// pzheev_(&jobz, &uplo, &n,
214+
// mat, &i1, &i1, desc,
215+
// eigval, eigvec, &i1, &i1, desc,
216+
// work.data(), &minus_one, rwork.data(), &minus_one, &info); // get the optimal workspace size
217+
/// try pzheevd
218+
// int liwork = 0;
219+
// pzheevd_(&jobz, &uplo, &n,
220+
// mat, &i1, &i1, desc,
221+
// eigval, eigvec, &i1, &i1, desc,
222+
// &lwork_tmp, &minus_one, &lrwork_tmp, &minus_one, &liwork, &minus_one, &info); // get the optimal workspace size
223+
224+
// try pzheevx
225+
const char range = 'A';
226+
const double zero = 0.0;
227+
double abstol = 0.0;
228+
int nz = n;
229+
std::vector<int> iwork(1, 0);
230+
std::vector<int> ifail(n, 0);
231+
std::vector<int> iclustr(2 * GlobalV::DSIZE);
232+
std::vector<double> gap(GlobalV::DSIZE);
233+
pzheevx_(&jobz, &range, &uplo, &n,
214234
mat, &i1, &i1, desc,
215-
eigval, eigvec, &i1, &i1, desc,
216-
&lwork_tmp, &minus_one, &lrwork_tmp, &minus_one, &info); // get the optimal workspace size
217-
const int lwork = lwork_tmp.real();
218-
const int lrwork = lrwork_tmp;
219-
std::vector<std::complex<double>> work(lwork);
220-
std::vector<double>rwork(lrwork);
221-
pzheev_(&jobz, &uplo, &n,
235+
&zero, &zero, &i1, &i1, &zero,
236+
&nz, &nz, eigval, &zero,
237+
eigvec, &i1, &i1, desc,
238+
work.data(), &minus_one, rwork.data(), &minus_one, iwork.data(), &minus_one,
239+
ifail.data(), iclustr.data(), gap.data(), &info);
240+
241+
const int lwork = work.at(0).real();
242+
work.resize(lwork);
243+
const int lrwork = rwork.at(0);
244+
rwork.resize(lrwork);
245+
const int liwork = iwork.at(0);
246+
iwork.resize(liwork);
247+
// std::cout << "pzheevx: query result: lwork=" << work.at(0) << ", lrwork=" << rwork.at(0) << ", liwork=" << iwork.at(0) << std::endl;
248+
249+
// pzheev_(&jobz, &uplo, &n,
250+
// mat, &i1, &i1, desc,
251+
// eigval, eigvec, &i1, &i1, desc,
252+
// work.data(), &lwork, rwork.data(), &lrwork, &info);
253+
// std::vector<int> iwork(liwork);
254+
// pzheevd_(&jobz, &uplo, &n,
255+
// mat, &i1, &i1, desc,
256+
// eigval, eigvec, &i1, &i1, desc,
257+
// work.data(), &lwork, rwork.data(), &lrwork, iwork.data(), &liwork, &info);
258+
pzheevx_(&jobz, &range, &uplo, &n,
222259
mat, &i1, &i1, desc,
223-
eigval, eigvec, &i1, &i1, desc,
224-
work.data(), &lwork, rwork.data(), &lrwork, &info);
260+
&zero, &zero, &i1, &i1, &zero,
261+
&nz, &nz, eigval, &zero,
262+
eigvec, &i1, &i1, desc,
263+
work.data(), &lwork, rwork.data(), &lrwork, iwork.data(), &liwork,
264+
ifail.data(), iclustr.data(), gap.data(), &info);
265+
if (info) { std::cout << "ERROR: Scalapack solver, info=" << info << std::endl; }
266+
}
267+
268+
void diag_scalapack(const int& n, std::complex<double>* hmat, std::complex<double>* const smat, double* eigval, std::complex<double>* eigvec, const int(&desc)[9])
269+
{
270+
ModuleBase::TITLE("LR_Util", "diag_lapack<complex<double>>");
271+
const char jobz = 'V', uplo = 'U', range = 'A';
272+
int minus_one = -1;
273+
const int i1 = 1;
274+
const double zero = 0.0;
275+
int info = 0;
276+
double abstol = 0.0;
277+
int nz = n;
278+
std::vector<std::complex<double>> work(1, 0.0);
279+
std::vector<double>rwork(1, 0.0);
280+
std::vector<int> iwork(1, 0);
281+
std::vector<int> ifail(n, 0);
282+
std::vector<int> iclustr(2 * GlobalV::DSIZE);
283+
std::vector<double> gap(GlobalV::DSIZE);
284+
pzhegvx_(&i1, &jobz, &range, &uplo, &n,
285+
hmat, &i1, &i1, desc, smat, &i1, &i1, desc,
286+
&zero, &zero, &i1, &i1, &zero,
287+
&nz, &nz, eigval, &zero,
288+
eigvec, &i1, &i1, desc,
289+
work.data(), &minus_one, rwork.data(), &minus_one, iwork.data(), &minus_one,
290+
ifail.data(), iclustr.data(), gap.data(), &info);
291+
292+
int lwork = work.at(0).real();
293+
work.resize(lwork);
294+
int lrwork = rwork.at(0);
295+
rwork.resize(lrwork);
296+
int liwork = iwork.at(0);
297+
iwork.resize(liwork);
298+
// std::cout << "pzhegvx: query result: lwork=" << work.at(0) << ", lrwork=" << rwork.at(0) << ", liwork=" << iwork.at(0) << std::endl;
299+
pzhegvx_(&i1, &jobz, &range, &uplo, &n,
300+
hmat, &i1, &i1, desc,
301+
smat, &i1, &i1, desc,
302+
&zero, &zero, &i1, &i1, &zero,
303+
&nz, &nz, eigval, &zero,
304+
eigvec, &i1, &i1, desc,
305+
work.data(), &lwork, rwork.data(), &lrwork, iwork.data(), &liwork,
306+
ifail.data(), iclustr.data(), gap.data(), &info);
225307
if (info) { std::cout << "ERROR: Scalapack solver, info=" << info << std::endl; }
226308
}
227309
#endif

source/module_lr/utils/lr_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ namespace LR_Util
108108
#ifdef __MPI
109109
void diag_scalapack(const int& n, double* mat, double* eigval, double* eigvec, const int(&desc)[9]);
110110
void diag_scalapack(const int& n, std::complex<double>* mat, double* eigval, std::complex<double>* eigvec, const int(&desc)[9]);
111+
void diag_scalapack(const int& n, std::complex<double>* hmat, std::complex<double>* smat, double* eigval, std::complex<double>* eigvec, const int(&desc)[9]);
111112
#endif
112113
///=================string option====================
113114
std::string tolower(const std::string& str);

source/module_lr/utils/lr_util_print.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ namespace LR_Util
201201
const int& iat2 = tmp2.first.first;
202202
const auto& R = tmp2.first.second;
203203
auto& t = tmp2.second;
204-
if (R != TC({ 0, 0, 0 })) {continue;} // for test
204+
// if (R != TC({ 0, 0, 0 })) {continue;} // for test
205205
std::cout << "iat1=" << iat1 << " iat2=" << iat2 << " R=(" << R[0] << " " << R[1] << " " << R[2] << ")\n";
206206
if (t.shape.size() == 2)
207207
{

source/module_lr/utils/test/lr_util_algorithms_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ inline void check_double_eq(std::complex<double>* data1, std::complex<double>* d
1616
EXPECT_NEAR(data1[i].imag(), data2[i].imag(), 1e-10);
1717
}
1818
};
19+
inline void check_norm_eq(double* data1, double* data2, int size)
20+
{
21+
for (int i = 0;i < size;++i)
22+
{
23+
EXPECT_NEAR(std::norm(data1[i]), std::norm(data2[i]), 1e-10);
24+
}
25+
};
1926
inline void check_norm_eq(std::complex<double>* data1, std::complex<double>* data2, int size)
2027
{
2128
for (int i = 0;i < size;++i)
@@ -201,6 +208,40 @@ TEST(LR_Util, DiagScaLapackDouble)
201208
check_norm_eq(eigvec_para.data(), eigvec_serial_local.data(), pmat.get_local_size());
202209
}
203210

211+
212+
TEST(LR_Util, DiagScaLapackGeneralComplex)
213+
{
214+
// setup the matrix
215+
const int dim = 15;
216+
std::vector<std::complex<double>> mat(dim * dim);
217+
set_rand(mat.data(), dim * dim);
218+
LR_Util::matsym(mat.data(), dim);
219+
Parallel_2D pmat;
220+
LR_Util::setup_2d_division(pmat, 1, dim, dim);
221+
std::vector<std::complex<double>> hmat_local(pmat.get_local_size());
222+
LR_Util::set_local_from_global(pmat, mat.data(), hmat_local.data());
223+
std::vector<std::complex<double>> smat_local(pmat.get_local_size(), 0.0);
224+
for (int lj = 0;lj < pmat.get_col_size();++lj)
225+
for (int li = 0;li < pmat.get_row_size();++li)
226+
if (pmat.local2global_row(li) == pmat.local2global_col(lj)) // diagonal elements
227+
smat_local[li * pmat.get_row_size() + lj] = std::complex<double>(1.0, 0.0);
228+
229+
// serial
230+
std::vector<double> eig(dim);
231+
LR_Util::diag_lapack(dim, mat.data(), eig.data());
232+
233+
// parallel
234+
std::vector<double> eig_para(dim);
235+
std::vector<std::complex<double>> eigvec_para(pmat.get_local_size());
236+
LR_Util::diag_scalapack(dim, hmat_local.data(), smat_local.data(), eig_para.data(), eigvec_para.data(), pmat.desc);
237+
238+
// compare
239+
check_double_eq(eig_para.data(), eig.data(), dim);
240+
std::vector<std::complex<double>> eigvec_serial_local(pmat.get_local_size());
241+
LR_Util::set_local_from_global(pmat, mat.data(), eigvec_serial_local.data());
242+
check_norm_eq(eigvec_para.data(), eigvec_serial_local.data(), pmat.get_local_size());
243+
}
244+
204245
TEST(LR_Util, DiagScaLapackComplex)
205246
{
206247
// setup the matrix

0 commit comments

Comments
 (0)