Skip to content

Commit 123a0d3

Browse files
committed
add scalapack solver and test
1 parent 76d023b commit 123a0d3

File tree

5 files changed

+150
-1
lines changed

5 files changed

+150
-1
lines changed

source/module_base/scalapack_connector.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,15 @@ extern "C"
9999
const float* abstol, int* m, int* nz, float* w, const float*orfac, std::complex<float>* Z, const int* iz, const int* jz, const int*descz,
100100
std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int*iwork, int*liwork, int* ifail, int*iclustr, float*gap, int* info);
101101

102+
/// ScaLAPACK PDSYEV: eigen-solver for a real symmetric distributed matrix.
/// `a`/`desca` describe the input matrix, `w` receives the eigenvalues and
/// `z`/`descz` the eigenvectors. Callers in this project pass `lwork = -1`
/// first to obtain the required workspace size in `work[0]`.
void pdsyev_(
    const char* jobz,
    const char* uplo,
    const int* n,
    double* a,
    const int* ia,
    const int* ja,
    const int* desca,
    double* w,
    double* z,
    const int* iz,
    const int* jz,
    const int* descz,
    double* work,
    const int* lwork,
    int* info);
106+
107+
/// ScaLAPACK PZHEEV: eigen-solver for a complex Hermitian distributed matrix.
/// `a`/`desca` describe the input matrix, `w` receives the (real) eigenvalues
/// and `z`/`descz` the eigenvectors. Callers in this project pass
/// `lwork = lrwork = -1` first to obtain the required sizes of the complex
/// (`work`) and real (`rwork`) workspaces.
void pzheev_(
    const char* jobz,
    const char* uplo,
    const int* n,
    std::complex<double>* a,
    const int* ia,
    const int* ja,
    const int* desca,
    double* w,
    std::complex<double>* z,
    const int* iz,
    const int* jz,
    const int* descz,
    std::complex<double>* work,
    const int* lwork,
    double* rwork,
    const int* lrwork,
    int* info);
102111

103112
void pzgetri_(
104113
const int *n,

source/module_lr/utils/lr_util.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,52 @@ namespace LR_Util
180180
if (info) { std::cout << "ERROR: Lapack solver zgeev, info=" << info << std::endl; }
181181
}
182182

183+
#ifdef __MPI
184+
void diag_scalapack(const int& n, double* mat, double* eigval, double* eigvec, const int(&desc)[9])
185+
{
186+
ModuleBase::TITLE("LR_Util", "diag_scalapack<double>");
187+
char jobz = 'V', uplo = 'U';
188+
const int minus_one = -1;
189+
const int i1 = 1;
190+
int info = 0;
191+
double lwork_tmp = 0.0;
192+
pdsyev_(&jobz, &uplo, &n,
193+
mat, &i1, &i1, desc,
194+
eigval, eigvec, &i1, &i1, desc,
195+
&lwork_tmp, &minus_one, &info); // get the optimal size of work into lwork
196+
const int lwork = lwork_tmp;
197+
std::vector<double> work(lwork);
198+
pdsyev_(&jobz, &uplo, &n,
199+
mat, &i1, &i1, desc,
200+
eigval, eigvec, &i1, &i1, desc,
201+
work.data(), &lwork, &info);
202+
if (info) { std::cout << "ERROR: Scalapack solver, info=" << info << std::endl; }
203+
}
204+
void diag_scalapack(const int& n, std::complex<double>* mat, double* eigval, std::complex<double>* eigvec, const int(&desc)[9])
205+
{
206+
ModuleBase::TITLE("LR_Util", "diag_lapack<complex<double>>");
207+
char jobz = 'V', uplo = 'U';
208+
const int minus_one = -1;
209+
const int i1 = 1;
210+
int info = 0;
211+
std::complex<double>lwork_tmp(0., 0.);
212+
double lrwork_tmp = 0.0;
213+
pzheev_(&jobz, &uplo, &n,
214+
mat, &i1, &i1, desc,
215+
eigval, eigvec, &i1, &i1, desc,
216+
&lwork_tmp, &minus_one, &lrwork_tmp, &minus_one, &info); // get the optimal workspace size
217+
const int lwork = lwork_tmp.real();
218+
const int lrwork = lrwork_tmp;
219+
std::vector<std::complex<double>> work(lwork);
220+
std::vector<double>rwork(lrwork);
221+
pzheev_(&jobz, &uplo, &n,
222+
mat, &i1, &i1, desc,
223+
eigval, eigvec, &i1, &i1, desc,
224+
work.data(), &lwork, rwork.data(), &lrwork, &info);
225+
if (info) { std::cout << "ERROR: Scalapack solver, info=" << info << std::endl; }
226+
}
227+
#endif
228+
183229
std::string tolower(const std::string& str)
184230
{
185231
std::string str_lower = str;

source/module_lr/utils/lr_util.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ namespace LR_Util
9494
/// the definition of row and col is consistent with setup_2d_division
9595
template <typename T>
9696
void gather_2d_to_full(const Parallel_2D& pv, const T* submat, T* fullmat, bool col_first, int global_nrow, int global_ncol);
97+
template <typename T>
98+
void set_local_from_global(const Parallel_2D& pv, const T* global, T* local);
9799
#endif
98100

99101
///=================diago-lapack====================
@@ -103,7 +105,10 @@ namespace LR_Util
103105
/// @brief diagonalize a general matrix
104106
void diag_lapack_nh(const int& n, double* mat, std::complex<double>* eig);
105107
void diag_lapack_nh(const int& n, std::complex<double>* mat, std::complex<double>* eig);
106-
108+
#ifdef __MPI
/// @brief diagonalize a real symmetric 2D-distributed matrix with ScaLAPACK
/// @param n      global matrix dimension
/// @param mat    local part of the matrix
/// @param eigval output eigenvalues
/// @param eigvec local part of the output eigenvectors (same descriptor as mat)
/// @param desc   array descriptor of mat and eigvec
void diag_scalapack(
    const int& n,
    double* mat,
    double* eigval,
    double* eigvec,
    const int(&desc)[9]);
/// @brief diagonalize a complex Hermitian 2D-distributed matrix with ScaLAPACK
/// (same parameters as the real version; eigenvalues are real)
void diag_scalapack(
    const int& n,
    std::complex<double>* mat,
    double* eigval,
    std::complex<double>* eigvec,
    const int(&desc)[9]);
#endif
107112
///=================string option====================
108113
std::string tolower(const std::string& str);
109114
std::string toupper(const std::string& str);

source/module_lr/utils/lr_util.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,18 @@ namespace LR_Util
174174
//reduce to root
175175
MPI_Allreduce(MPI_IN_PLACE, fullmat, global_nrow * global_ncol, get_mpi_datatype(), MPI_SUM, pv.comm());
176176
};
177+
178+
template <typename T>
179+
void set_local_from_global(const Parallel_2D& pv, const T* global, T* local)
180+
{
181+
for (int c = 0;c < pv.get_col_size();++c)
182+
{
183+
for (int r = 0;r < pv.get_row_size();++r)
184+
{
185+
local[c * pv.get_row_size() + r] = global[pv.local2global_col(c) * pv.get_global_row_size() + pv.local2global_row(r)];
186+
}
187+
}
188+
}
177189
#endif
178190

179191
}

source/module_lr/utils/test/lr_util_algorithms_test.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,27 @@
33
#include "../lr_util.h"
44
#include "../lr_util_print.h"
55

6+
inline void check_double_eq(double* data1, double* data2, int size)
7+
{
8+
for (int i = 0;i < size;++i)
9+
EXPECT_NEAR(data1[i], data2[i], 1e-10);
10+
};
11+
inline void check_double_eq(std::complex<double>* data1, std::complex<double>* data2, int size)
12+
{
13+
for (int i = 0;i < size;++i)
14+
{
15+
EXPECT_NEAR(data1[i].real(), data2[i].real(), 1e-10);
16+
EXPECT_NEAR(data1[i].imag(), data2[i].imag(), 1e-10);
17+
}
18+
};
19+
inline void check_norm_eq(std::complex<double>* data1, std::complex<double>* data2, int size)
20+
{
21+
for (int i = 0;i < size;++i)
22+
{
23+
EXPECT_NEAR(std::norm(data1[i]), std::norm(data2[i]), 1e-10);
24+
}
25+
}
26+
627
TEST(LR_Util, PsiWrapper)
728
{
829
int nk = 2;
@@ -152,6 +173,62 @@ TEST(LR_Util, RWValue)
152173
for (int i = 0;i < vec2.size();++i) { EXPECT_EQ(vec2[i], vec[i]); };
153174
}
154175

176+
// diag_scalapack is only declared when __MPI is defined, so the test must be guarded the same way.
#ifdef __MPI
/// Cross-check the ScaLAPACK real symmetric solver against the serial LAPACK solver.
TEST(LR_Util, DiagScaLapackDouble)
{
    // setup a random symmetric matrix and its 2D-distributed local copy
    const int dim = 14;
    std::vector<double> mat(dim * dim);
    set_rand(mat.data(), dim * dim);
    LR_Util::matsym(mat.data(), dim);
    Parallel_2D pmat;
    LR_Util::setup_2d_division(pmat, 1, dim, dim);
    std::vector<double> mat_local(pmat.get_local_size(), 0.0);
    LR_Util::set_local_from_global(pmat, mat.data(), mat_local.data());

    // serial reference
    std::vector<double> eig(dim);
    LR_Util::diag_lapack(dim, mat.data(), eig.data());

    // parallel
    std::vector<double> eig_para(dim);
    std::vector<double> eigvec_para(pmat.get_local_size());
    LR_Util::diag_scalapack(dim, mat_local.data(), eig_para.data(), eigvec_para.data(), pmat.desc);

    // compare eigenvalues
    check_double_eq(eig_para.data(), eig.data(), dim);

    // compare eigenvectors: `mat` is redistributed after the serial solve and used as the
    // reference (presumably diag_lapack leaves the eigenvectors in it -- confirm).
    std::vector<double> eigvec_serial_local(pmat.get_local_size());
    LR_Util::set_local_from_global(pmat, mat.data(), eigvec_serial_local.data());
    // A real eigenvector is only defined up to its sign, and the two solvers need not pick
    // the same one. Compare squared entries, as the complex test does with check_norm_eq.
    const int nlocal = pmat.get_local_size();
    for (int i = 0; i < nlocal; ++i)
    {
        EXPECT_NEAR(eigvec_para[i] * eigvec_para[i], eigvec_serial_local[i] * eigvec_serial_local[i], 1e-10);
    }
}
#endif
203+
204+
// diag_scalapack is only declared when __MPI is defined, so the test must be guarded the same way.
#ifdef __MPI
/// Cross-check the ScaLAPACK complex solver against the serial LAPACK solver.
TEST(LR_Util, DiagScaLapackComplex)
{
    // setup a random matrix, symmetrized by LR_Util::matsym, and its 2D-distributed local copy
    const int dim = 15;
    std::vector<std::complex<double>> mat(dim * dim);
    set_rand(mat.data(), dim * dim);
    LR_Util::matsym(mat.data(), dim);
    Parallel_2D pmat;
    LR_Util::setup_2d_division(pmat, 1, dim, dim);
    std::vector<std::complex<double>> mat_local(pmat.get_local_size(), 0.0);
    LR_Util::set_local_from_global(pmat, mat.data(), mat_local.data());

    // serial reference
    std::vector<double> eig(dim);
    LR_Util::diag_lapack(dim, mat.data(), eig.data());

    // parallel
    std::vector<double> eig_para(dim);
    std::vector<std::complex<double>> eigvec_para(pmat.get_local_size());
    LR_Util::diag_scalapack(dim, mat_local.data(), eig_para.data(), eigvec_para.data(), pmat.desc);

    // compare eigenvalues
    check_double_eq(eig_para.data(), eig.data(), dim);

    // compare eigenvectors: `mat` is redistributed after the serial solve and used as the
    // reference (presumably diag_lapack leaves the eigenvectors in it -- confirm).
    // Only squared magnitudes are compared, so the phase of each entry is ignored.
    std::vector<std::complex<double>> eigvec_serial_local(pmat.get_local_size());
    LR_Util::set_local_from_global(pmat, mat.data(), eigvec_serial_local.data());
    check_norm_eq(eigvec_para.data(), eigvec_serial_local.data(), pmat.get_local_size());
}
#endif
231+
155232
int main(int argc, char** argv)
156233
{
157234
srand(time(NULL)); // for random number generator

0 commit comments

Comments
 (0)