@@ -184,7 +184,7 @@ namespace LR_Util
184184 void diag_scalapack (const int & n, double * mat, double * eigval, double * eigvec, const int (&desc)[9])
185185 {
186186 ModuleBase::TITLE (" LR_Util" , " diag_scalapack<double>" );
187- char jobz = ' V' , uplo = ' U' ;
187+ const char jobz = ' V' , uplo = ' U' ;
188188 const int minus_one = -1 ;
189189 const int i1 = 1 ;
190190 int info = 0 ;
@@ -204,24 +204,106 @@ namespace LR_Util
204204 void diag_scalapack (const int & n, std::complex <double >* mat, double * eigval, std::complex <double >* eigvec, const int (&desc)[9])
205205 {
206206 ModuleBase::TITLE (" LR_Util" , " diag_lapack<complex<double>>" );
207- char jobz = ' V' , uplo = ' U' ;
207+ const char jobz = ' V' , uplo = ' U' ;
208208 const int minus_one = -1 ;
209209 const int i1 = 1 ;
210210 int info = 0 ;
211- std::complex <double >lwork_tmp (0 ., 0 .);
212- double lrwork_tmp = 0.0 ;
213- pzheev_ (&jobz, &uplo, &n,
211+ std::vector<std::complex <double >> work (1 , 0.0 );
212+ std::vector<double >rwork (1 , 0.0 );
213+ // pzheev_(&jobz, &uplo, &n,
214+ // mat, &i1, &i1, desc,
215+ // eigval, eigvec, &i1, &i1, desc,
216+ // work.data(), &minus_one, rwork.data(), &minus_one, &info); // get the optimal workspace size
217+ // / try pzheevd
218+ // int liwork = 0;
219+ // pzheevd_(&jobz, &uplo, &n,
220+ // mat, &i1, &i1, desc,
221+ // eigval, eigvec, &i1, &i1, desc,
222+ // &lwork_tmp, &minus_one, &lrwork_tmp, &minus_one, &liwork, &minus_one, &info); // get the optimal workspace size
223+
224+ // try pzheevx
225+ const char range = ' A' ;
226+ const double zero = 0.0 ;
227+ double abstol = 0.0 ;
228+ int nz = n;
229+ std::vector<int > iwork (1 , 0 );
230+ std::vector<int > ifail (n, 0 );
231+ std::vector<int > iclustr (2 * GlobalV::DSIZE);
232+ std::vector<double > gap (GlobalV::DSIZE);
233+ pzheevx_ (&jobz, &range, &uplo, &n,
214234 mat, &i1, &i1, desc,
215- eigval, eigvec, &i1, &i1, desc,
216- &lwork_tmp, &minus_one, &lrwork_tmp, &minus_one, &info); // get the optimal workspace size
217- const int lwork = lwork_tmp.real ();
218- const int lrwork = lrwork_tmp;
219- std::vector<std::complex <double >> work (lwork);
220- std::vector<double >rwork (lrwork);
221- pzheev_ (&jobz, &uplo, &n,
235+ &zero, &zero, &i1, &i1, &zero,
236+ &nz, &nz, eigval, &zero,
237+ eigvec, &i1, &i1, desc,
238+ work.data (), &minus_one, rwork.data (), &minus_one, iwork.data (), &minus_one,
239+ ifail.data (), iclustr.data (), gap.data (), &info);
240+
241+ const int lwork = work.at (0 ).real ();
242+ work.resize (lwork);
243+ const int lrwork = rwork.at (0 );
244+ rwork.resize (lrwork);
245+ const int liwork = iwork.at (0 );
246+ iwork.resize (liwork);
247+ // std::cout << "pzheevx: query result: lwork=" << work.at(0) << ", lrwork=" << rwork.at(0) << ", liwork=" << iwork.at(0) << std::endl;
248+
249+ // pzheev_(&jobz, &uplo, &n,
250+ // mat, &i1, &i1, desc,
251+ // eigval, eigvec, &i1, &i1, desc,
252+ // work.data(), &lwork, rwork.data(), &lrwork, &info);
253+ // std::vector<int> iwork(liwork);
254+ // pzheevd_(&jobz, &uplo, &n,
255+ // mat, &i1, &i1, desc,
256+ // eigval, eigvec, &i1, &i1, desc,
257+ // work.data(), &lwork, rwork.data(), &lrwork, iwork.data(), &liwork, &info);
258+ pzheevx_ (&jobz, &range, &uplo, &n,
222259 mat, &i1, &i1, desc,
223- eigval, eigvec, &i1, &i1, desc,
224- work.data (), &lwork, rwork.data (), &lrwork, &info);
260+ &zero, &zero, &i1, &i1, &zero,
261+ &nz, &nz, eigval, &zero,
262+ eigvec, &i1, &i1, desc,
263+ work.data (), &lwork, rwork.data (), &lrwork, iwork.data (), &liwork,
264+ ifail.data (), iclustr.data (), gap.data (), &info);
265+ if (info) { std::cout << " ERROR: Scalapack solver, info=" << info << std::endl; }
266+ }
267+
268+ void diag_scalapack (const int & n, std::complex <double >* hmat, std::complex <double >* const smat, double * eigval, std::complex <double >* eigvec, const int (&desc)[9])
269+ {
270+ ModuleBase::TITLE (" LR_Util" , " diag_lapack<complex<double>>" );
271+ const char jobz = ' V' , uplo = ' U' , range = ' A' ;
272+ int minus_one = -1 ;
273+ const int i1 = 1 ;
274+ const double zero = 0.0 ;
275+ int info = 0 ;
276+ double abstol = 0.0 ;
277+ int nz = n;
278+ std::vector<std::complex <double >> work (1 , 0.0 );
279+ std::vector<double >rwork (1 , 0.0 );
280+ std::vector<int > iwork (1 , 0 );
281+ std::vector<int > ifail (n, 0 );
282+ std::vector<int > iclustr (2 * GlobalV::DSIZE);
283+ std::vector<double > gap (GlobalV::DSIZE);
284+ pzhegvx_ (&i1, &jobz, &range, &uplo, &n,
285+ hmat, &i1, &i1, desc, smat, &i1, &i1, desc,
286+ &zero, &zero, &i1, &i1, &zero,
287+ &nz, &nz, eigval, &zero,
288+ eigvec, &i1, &i1, desc,
289+ work.data (), &minus_one, rwork.data (), &minus_one, iwork.data (), &minus_one,
290+ ifail.data (), iclustr.data (), gap.data (), &info);
291+
292+ int lwork = work.at (0 ).real ();
293+ work.resize (lwork);
294+ int lrwork = rwork.at (0 );
295+ rwork.resize (lrwork);
296+ int liwork = iwork.at (0 );
297+ iwork.resize (liwork);
298+ // std::cout << "pzhegvx: query result: lwork=" << work.at(0) << ", lrwork=" << rwork.at(0) << ", liwork=" << iwork.at(0) << std::endl;
299+ pzhegvx_ (&i1, &jobz, &range, &uplo, &n,
300+ hmat, &i1, &i1, desc,
301+ smat, &i1, &i1, desc,
302+ &zero, &zero, &i1, &i1, &zero,
303+ &nz, &nz, eigval, &zero,
304+ eigvec, &i1, &i1, desc,
305+ work.data (), &lwork, rwork.data (), &lrwork, iwork.data (), &liwork,
306+ ifail.data (), iclustr.data (), gap.data (), &info);
225307 if (info) { std::cout << " ERROR: Scalapack solver, info=" << info << std::endl; }
226308 }
227309#endif
0 commit comments