99#include " elpa_new.h"
1010#include " elpa_solver.h"
1111
12- #include " my_math.hpp "
12+ #include " module_base/scalapack_connector.h "
1313#include " utils.h"
1414
1515extern std::map<int , elpa_t > NEW_ELPA_HANDLE_POOL;
@@ -72,7 +72,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
7272 t=-1 ;
7373 timer (myid, " A*U^-1" , " 2.1a" , t);
7474 }
75- Cpzgemm (' C' , ' N' , nFull, 1.0 , A, B, 0.0 , zwork.data (), desc);
75+ ScalapackConnector::gemm (' C' , ' N' , nFull, nFull , nFull, 1.0 , A, B, 0.0 , zwork.data (), desc);
7676 if (loglevel>1 )
7777 {
7878 timer (myid, " A*U^-1" , " 2.1a" , t);
@@ -84,7 +84,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
8484 t=-1 ;
8585 timer (myid, " U^-T*(A*U^-1)" , " 2.2a" , t);
8686 }
87- Cpzgemm (' C' , ' N' , nFull, 1.0 , B, zwork.data (), 0.0 , A, desc);
87+ ScalapackConnector::gemm (' C' , ' N' , nFull, nFull , nFull, 1.0 , B, zwork.data (), 0.0 , A, desc);
8888 if (loglevel>1 )
8989 {
9090 timer (myid, " U^-T*(A*U^-1)" , " 2.2a" , t);
@@ -98,7 +98,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
9898 t=-1 ;
9999 timer (myid, " B*A^T" , " 2.1b" , t);
100100 }
101- Cpzgemm (' N' , ' C' , nFull, 1.0 , B, A, 0.0 , zwork.data (), desc);
101+ ScalapackConnector::gemm (' N' , ' C' , nFull, nFull , nFull, 1.0 , B, A, 0.0 , zwork.data (), desc);
102102 if (loglevel>1 )
103103 {
104104 timer (myid, " B*A^T" , " 2.1b" , t);
@@ -109,7 +109,7 @@ int ELPA_Solver::generalized_eigenvector(std::complex<double>* A, std::complex<d
109109 t=-1 ;
110110 timer (myid, " B*(B*A^T)^T" , " 2.2b" , t);
111111 }
112- Cpzgemm (' N' , ' C' , nFull, 1.0 , B, zwork.data (), 0.0 , A, desc);
112+ ScalapackConnector::gemm (' N' , ' C' , nFull, nFull , nFull, 1.0 , B, zwork.data (), 0.0 , A, desc);
113113 if (loglevel>1 )
114114 {
115115 timer (myid, " B*(B*A^T)^T" , " 2.2b" , t);
@@ -168,7 +168,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
168168 t=-1 ;
169169 timer (myid, " pzpotrf_" , " 1" , t);
170170 }
171- Cpzpotrf (' U' , nFull, B, desc);
171+ ScalapackConnector::potrf (' U' , nFull, B, desc);
172172 if (loglevel>1 )
173173 {
174174 timer (myid, " pzpotrf_" , " 1" , t);
@@ -214,7 +214,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
214214 t=-1 ;
215215 timer (myid, " pzpotrf_" , " 2" , t);
216216 }
217- Cpzpotrf (' U' , nFull, B, desc);
217+ ScalapackConnector::potrf (' U' , nFull, B, desc);
218218 if (loglevel>1 )
219219 {
220220 timer (myid, " pzpotrf_" , " 2" , t);
@@ -290,7 +290,7 @@ int ELPA_Solver::decomposeRightMatrix(std::complex<double>* B, double* EigenValu
290290 t=-1 ;
291291 timer (myid, " qevq=qev*q^T" , " 2" , t);
292292 }
293- Cpzgemm (' N' , ' C' , nFull, 1.0 , zwork.data (), EigenVector, 0.0 , B, desc);
293+ ScalapackConnector::gemm (' N' , ' C' , nFull, nFull , nFull, 1.0 , zwork.data (), EigenVector, 0.0 , B, desc);
294294 if (loglevel>1 )
295295 {
296296 timer (myid, " qevq=qev*q^T" , " 2" , t);
@@ -310,7 +310,7 @@ int ELPA_Solver::composeEigenVector(int DecomposedState, std::complex<double>* B
310310 t=-1 ;
311311 timer (myid, " Cpztrmm" , " 1" , t);
312312 }
313- Cpztrmm (' L' , ' U' , ' N' , ' N' , nFull, nev, 1.0 , B, EigenVector, desc);
313+ ScalapackConnector::trmm (' L' , ' U' , ' N' , ' N' , nFull, nev, 1.0 , B, EigenVector, desc);
314314 if (loglevel>1 )
315315 {
316316 timer (myid, " Cpztrmm" , " 1" , t);
@@ -322,7 +322,7 @@ int ELPA_Solver::composeEigenVector(int DecomposedState, std::complex<double>* B
322322 t=-1 ;
323323 timer (myid, " Cpzgemm" , " 1" , t);
324324 }
325- Cpzgemm (' C' , ' N' , nFull, nev, nFull, 1.0 , B, zwork.data (), 0.0 , EigenVector, desc);
325+ ScalapackConnector::gemm (' C' , ' N' , nFull, nev, nFull, 1.0 , B, zwork.data (), 0.0 , EigenVector, desc);
326326 if (loglevel>1 )
327327 {
328328 timer (myid, " Cpzgemm" , " 1" , t);
@@ -368,19 +368,19 @@ void ELPA_Solver::verify(std::complex<double>* A, double* EigenValue, std::compl
368368 }
369369
370370 // R=V*D
371- Cpzhemm (' R' , ' U' , nFull, 1.0 , D, V, 0.0 , R, desc);
371+ ScalapackConnector::hemm (' R' , ' U' , nFull, 1.0 , D, V, 0.0 , R, desc);
372372 if (loglevel>2 ) saveMatrix (" VD.dat" , nFull, R, desc, cblacs_ctxt);
373373 // R=A*V-V*D=A*V-R
374- Cpzhemm (' L' , ' U' , nFull, 1.0 , A, V, -1.0 , R, desc);
374+ ScalapackConnector::hemm (' L' , ' U' , nFull, 1.0 , A, V, -1.0 , R, desc);
375375 if (loglevel>2 ) saveMatrix (" AV-VD.dat" , nFull, R, desc, cblacs_ctxt);
376376 // calculate the maximum and mean value of sum_i{R(:,i)*R(:,i)}
377377 double sumError=0 ;
378378 maxError=0 ;
379379 for (int i=1 ; i<=nev; ++i)
380380 {
381381 std::complex <double > E;
382- Cpzdotc (nFull, E, R, 1 , i, 1 ,
383- R, 1 , i, 1 , desc);
382+ ScalapackConnector::dot (nFull, E, R, 1 , i, 1 ,
383+ R, 1 , i, 1 , desc);
384384 double abs_E=std::abs (E);
385385 sumError+=abs_E;
386386 maxError=std::max (maxError, abs_E);
@@ -427,22 +427,22 @@ void ELPA_Solver::verify(std::complex<double>* A, std::complex<double>* B,
427427 }
428428
429429 // zwork=B*V
430- Cpzhemm (' L' , ' U' , nFull, 1.0 , B, V, 0.0 , zwork.data (), desc);
430+ ScalapackConnector::hemm (' L' , ' U' , nFull, 1.0 , B, V, 0.0 , zwork.data (), desc);
431431 if (loglevel>2 ) saveMatrix (" BV.dat" , nFull, zwork.data (), desc, cblacs_ctxt);
432432 // R=B*V*D=zwork*D
433- Cpzhemm ( ' R ' , ' U' , nFull, 1.0 , D, zwork. data () , 0.0 , R , desc);
433+ ScalapackConnector::hemm ( ' L ' , ' U' , nFull, 1.0 , B, V , 0.0 , zwork. data () , desc);
434434 if (loglevel>2 ) saveMatrix (" BVD.dat" , nFull, R, desc, cblacs_ctxt);
435435 // R=A*V-B*V*D=A*V-R
436- Cpzhemm (' L' , ' U' , nFull, 1.0 , A , V, - 1 .0 , R , desc);
436+ ScalapackConnector::hemm (' L' , ' U' , nFull, 1.0 , B , V, 0 .0 , zwork. data () , desc);
437437 if (loglevel>2 ) saveMatrix (" AV-BVD.dat" , nFull, R, desc, cblacs_ctxt);
438438 // calculate the maximum and mean value of sum_i{R(:,i)*R(:,i)}
439439 double sumError=0 ;
440440 maxError=0 ;
441441 for (int i=1 ; i<=nev; ++i)
442442 {
443443 std::complex <double > E;
444- Cpzdotc (nFull, E, R, 1 , i, 1 ,
445- R, 1 , i, 1 , desc);
444+ ScalapackConnector::dot (nFull, E, R, 1 , i, 1 ,
445+ R, 1 , i, 1 , desc);
446446 double abs_E=std::abs (E);
447447 sumError+=abs_E;
448448 maxError=std::max (maxError, abs_E);
0 commit comments