@@ -120,12 +120,15 @@ void DiagoCusolver<T>::diag_pool(hamilt::MatrixBlock<T>& h_mat,
120120{
121121 ModuleBase::TITLE (" DiagoCusolver" , " diag_pool" );
122122 ModuleBase::timer::tick (" DiagoCusolver" , " diag_pool" );
123- std::vector<double > eigen (PARAM.globalv .nlocal , 0.0 );
123+ const int nbands_local = psi.get_nbands ();
124+ const int nbasis = psi.get_nbasis ();
125+ int nbands_global = nbands_local;
126+ std::vector<double > eigen (nbasis, 0.0 );
124127 std::vector<T> eigenvectors (h_mat.row * h_mat.col );
125128 this ->dc .Dngvd (h_mat.row , h_mat.col , h_mat.p , s_mat.p , eigen.data (), eigenvectors.data ());
126- const int size = psi. get_nbands () * psi. get_nbasis () ;
129+ const int size = nbands_local * nbasis ;
127130 BlasConnector::copy (size, eigenvectors.data (), 1 , psi.get_pointer (), 1 );
128- BlasConnector::copy (PARAM. inp . nbands , eigen.data (), 1 , eigenvalue_in, 1 );
131+ BlasConnector::copy (nbands_global , eigen.data (), 1 , eigenvalue_in, 1 );
129132 ModuleBase::timer::tick (" DiagoCusolver" , " diag_pool" );
130133}
131134
@@ -140,7 +143,14 @@ void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* e
140143 hamilt::MatrixBlock<T> h_mat;
141144 hamilt::MatrixBlock<T> s_mat;
142145 phm_in->matrix (h_mat, s_mat);
143-
146+ const int nbands_local = psi.get_nbands ();
147+ const int nbasis = psi.get_nbasis ();
148+ int nbands_global;
149+ #ifdef __MPI
150+ MPI_Allreduce (&nbands_local, &nbands_global, 1 , MPI_INT, MPI_SUM, this ->ParaV ->comm ());
151+ #else
152+ nbands_global = nbands_local;
153+ #endif
144154#ifdef __MPI
145155 // global matrix
146156 Matrix_g<T> h_mat_g;
@@ -159,7 +169,7 @@ void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* e
159169#endif
160170
161171 // Allocate memory for eigenvalues
162- std::vector<double > eigen (PARAM. globalv . nlocal , 0.0 );
172+ std::vector<double > eigen (nbasis , 0.0 );
163173
164174 // Start the timer for the cusolver operation
165175 ModuleBase::timer::tick (" DiagoCusolver" , " cusolver" );
@@ -189,31 +199,31 @@ void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* e
189199 MPI_Barrier (MPI_COMM_WORLD);
190200
191201 // broadcast eigenvalues to all processes
192- MPI_Bcast (eigen.data (), PARAM. inp . nbands , MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
202+ MPI_Bcast (eigen.data (), nbands_global , MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
193203
194204 // distribute psi to all processes
195205 distributePsi (this ->ParaV ->desc_wfc , psi.get_pointer (), psi_g.data ());
196206 }
197207 else
198208 {
199- // Be careful that h_mat.row * h_mat.col != psi.get_nbands() * psi.get_nbasis() under multi-k situation
209+ // Be careful that h_mat.row * h_mat.col != nbands * nbasis under multi-k situation
200210 std::vector<T> eigenvectors (h_mat.row * h_mat.col );
201211 this ->dc .Dngvd (h_mat.row , h_mat.col , h_mat.p , s_mat.p , eigen.data (), eigenvectors.data ());
202- const int size = psi. get_nbands () * psi. get_nbasis () ;
212+ const int size = nbands_local * nbasis ;
203213 BlasConnector::copy (size, eigenvectors.data (), 1 , psi.get_pointer (), 1 );
204214 }
205215#else
206216 std::vector<T> eigenvectors (h_mat.row * h_mat.col );
207217 this ->dc .Dngvd (h_mat.row , h_mat.col , h_mat.p , s_mat.p , eigen.data (), eigenvectors.data ());
208- const int size = psi. get_nbands () * psi. get_nbasis () ;
218+ const int size = nbands_local * nbasis ;
209219 BlasConnector::copy (size, eigenvectors.data (), 1 , psi.get_pointer (), 1 );
210220#endif
211221 // Stop the timer for the cusolver operation
212222 ModuleBase::timer::tick (" DiagoCusolver" , " cusolver" );
213223
214224 // Copy the eigenvalues to the output arrays
215225 const int inc = 1 ;
216- BlasConnector::copy (PARAM. inp . nbands , eigen.data (), inc, eigenvalue_in, inc);
226+ BlasConnector::copy (nbands_global , eigen.data (), inc, eigenvalue_in, inc);
217227}
218228
219229// Explicit instantiation of the DiagoCusolver class for real and complex numbers
0 commit comments