Skip to content

Commit e23f43e

Browse files
committed
remove global parameters
1 parent b5815c6 commit e23f43e

File tree

7 files changed

+58
-47
lines changed

7 files changed

+58
-47
lines changed

source/module_cell/parallel_kpoints.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
// the kpoints here are reduced after symmetry applied.
77
void Parallel_Kpoints::kinfo(int& nkstot_in,
8-
const int& kpar_in,
9-
const int& my_pool_in,
10-
const int& rank_in_pool_in,
11-
const int& nproc_in,
12-
const int& nspin_in)
8+
const int kpar_in,
9+
const int my_pool_in,
10+
const int rank_in_pool_in,
11+
const int nproc_in,
12+
const int nspin_in)
1313
{
1414
#ifdef __MPI
1515

source/module_cell/parallel_kpoints.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ class Parallel_Kpoints
1313
~Parallel_Kpoints(){};
1414

1515
void kinfo(int& nkstot_in,
16-
const int& kpar_in,
17-
const int& my_pool_in,
18-
const int& rank_in_pool_in,
19-
const int& nproc_in,
20-
const int& nspin_in);
16+
const int kpar_in,
17+
const int my_pool_in,
18+
const int rank_in_pool_in,
19+
const int nproc_in,
20+
const int nspin_in);
2121

2222
// collect value from each pool to wk.
2323
void pool_collection(double& value, const double* wk, const int& ik);

source/module_hsolver/diago_cusolver.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,15 @@ void DiagoCusolver<T>::diag_pool(hamilt::MatrixBlock<T>& h_mat,
120120
{
121121
ModuleBase::TITLE("DiagoCusolver", "diag_pool");
122122
ModuleBase::timer::tick("DiagoCusolver", "diag_pool");
123-
std::vector<double> eigen(PARAM.globalv.nlocal, 0.0);
123+
const int nbands_local = psi.get_nbands();
124+
const int nbasis = psi.get_nbasis();
125+
int nbands_global = nbands_local;
126+
std::vector<double> eigen(nbasis, 0.0);
124127
std::vector<T> eigenvectors(h_mat.row * h_mat.col);
125128
this->dc.Dngvd(h_mat.row, h_mat.col, h_mat.p, s_mat.p, eigen.data(), eigenvectors.data());
126-
const int size = psi.get_nbands() * psi.get_nbasis();
129+
const int size = nbands_local * nbasis;
127130
BlasConnector::copy(size, eigenvectors.data(), 1, psi.get_pointer(), 1);
128-
BlasConnector::copy(PARAM.inp.nbands, eigen.data(), 1, eigenvalue_in, 1);
131+
BlasConnector::copy(nbands_global, eigen.data(), 1, eigenvalue_in, 1);
129132
ModuleBase::timer::tick("DiagoCusolver", "diag_pool");
130133
}
131134

@@ -140,7 +143,14 @@ void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* e
140143
hamilt::MatrixBlock<T> h_mat;
141144
hamilt::MatrixBlock<T> s_mat;
142145
phm_in->matrix(h_mat, s_mat);
143-
146+
const int nbands_local = psi.get_nbands();
147+
const int nbasis = psi.get_nbasis();
148+
int nbands_global;
149+
#ifdef __MPI
150+
MPI_Allreduce(&nbands_local, &nbands_global, 1, MPI_INT, MPI_SUM, this->ParaV->comm());
151+
#else
152+
nbands_global = nbands_local;
153+
#endif
144154
#ifdef __MPI
145155
// global matrix
146156
Matrix_g<T> h_mat_g;
@@ -159,7 +169,7 @@ void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* e
159169
#endif
160170

161171
// Allocate memory for eigenvalues
162-
std::vector<double> eigen(PARAM.globalv.nlocal, 0.0);
172+
std::vector<double> eigen(nbasis, 0.0);
163173

164174
// Start the timer for the cusolver operation
165175
ModuleBase::timer::tick("DiagoCusolver", "cusolver");
@@ -189,31 +199,31 @@ void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* e
189199
MPI_Barrier(MPI_COMM_WORLD);
190200

191201
// broadcast eigenvalues to all processes
192-
MPI_Bcast(eigen.data(), PARAM.inp.nbands, MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
202+
MPI_Bcast(eigen.data(), nbands_global, MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
193203

194204
// distribute psi to all processes
195205
distributePsi(this->ParaV->desc_wfc, psi.get_pointer(), psi_g.data());
196206
}
197207
else
198208
{
199-
// Be careful that h_mat.row * h_mat.col != psi.get_nbands() * psi.get_nbasis() under multi-k situation
209+
// Be careful that h_mat.row * h_mat.col != nbands * nbasis under multi-k situation
200210
std::vector<T> eigenvectors(h_mat.row * h_mat.col);
201211
this->dc.Dngvd(h_mat.row, h_mat.col, h_mat.p, s_mat.p, eigen.data(), eigenvectors.data());
202-
const int size = psi.get_nbands() * psi.get_nbasis();
212+
const int size = nbands_local * nbasis;
203213
BlasConnector::copy(size, eigenvectors.data(), 1, psi.get_pointer(), 1);
204214
}
205215
#else
206216
std::vector<T> eigenvectors(h_mat.row * h_mat.col);
207217
this->dc.Dngvd(h_mat.row, h_mat.col, h_mat.p, s_mat.p, eigen.data(), eigenvectors.data());
208-
const int size = psi.get_nbands() * psi.get_nbasis();
218+
const int size = nbands_local * nbasis;
209219
BlasConnector::copy(size, eigenvectors.data(), 1, psi.get_pointer(), 1);
210220
#endif
211221
// Stop the timer for the cusolver operation
212222
ModuleBase::timer::tick("DiagoCusolver", "cusolver");
213223

214224
// Copy the eigenvalues to the output arrays
215225
const int inc = 1;
216-
BlasConnector::copy(PARAM.inp.nbands, eigen.data(), inc, eigenvalue_in, inc);
226+
BlasConnector::copy(nbands_global, eigen.data(), inc, eigenvalue_in, inc);
217227
}
218228

219229
// Explicit instantiation of the DiagoCusolver class for real and complex numbers

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#ifdef __CUDA
2020
#include "diago_cusolver.h"
21+
#include "module_base/module_device/device.h"
2122
#endif
2223

2324
#ifdef __PEXSI
@@ -186,6 +187,9 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
186187
{
187188
#ifdef __MPI
188189
ModuleBase::timer::tick("HSolverLCAO", "parakSolve");
190+
#ifdef __CUDA
191+
base_device::information::set_device_by_rank();
192+
#endif
189193
auto k2d = Parallel_K2D<T>();
190194
k2d.set_kpar(kpar);
191195
int nbands = this->ParaV->get_nbands();
@@ -194,10 +198,10 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
194198
int nb2d = this->ParaV->get_block_size();
195199
if(this->method == "cusolver")
196200
{
197-
k2d.set_para_env_cusolver(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK, PARAM.inp.nspin);
201+
k2d.set_para_env_cusolver(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK);
198202
} else
199203
{
200-
k2d.set_para_env(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK, PARAM.inp.nspin);
204+
k2d.set_para_env(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK);
201205
}
202206
/// set psi_pool
203207
const int zero = 0;
@@ -272,7 +276,7 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
272276
#ifdef __CUDA
273277
else if (this->method == "cusolver")
274278
{
275-
DiagoCusolver<T> cs(nullptr);
279+
DiagoCusolver<T> cs;
276280
cs.diag_pool(hk_pool, sk_pool, psi_pool, &(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D);
277281
}
278282
#endif

source/module_hsolver/kernels/cuda/diag_cusolver.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
#include <assert.h>
22
#include "diag_cusolver.cuh"
33
#include "helper_cuda.h"
4-
#include "module_base/module_device/device.h"
54

65
Diag_Cusolver_gvd::Diag_Cusolver_gvd(){
76
// step 1: create cusolver/cublas handle
8-
#if defined(__MPI) && defined(__CUDA)
9-
base_device::information::set_device_by_rank();
10-
#endif
117
cusolverH = NULL;
128
checkCudaErrors( cusolverDnCreate(&cusolverH) );
139

source/module_hsolver/parallel_k2d.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
template <typename TK>
99
void Parallel_K2D<TK>::set_para_env(int nks,
10-
const int& nw,
11-
const int& nb2d,
12-
const int& nproc,
13-
const int& my_rank,
14-
const int& nspin) {
10+
const int nw,
11+
const int nb2d,
12+
const int nproc,
13+
const int my_rank,
14+
const int nspin) {
1515
const int kpar = this->get_kpar();
1616
Parallel_Global::divide_mpi_groups(nproc,
1717
kpar,
@@ -36,11 +36,11 @@ void Parallel_K2D<TK>::set_para_env(int nks,
3636

3737
template <typename TK>
3838
void Parallel_K2D<TK>::set_para_env_cusolver(int nks,
39-
const int& nw,
40-
const int& nb2d,
41-
const int& nproc,
42-
const int& my_rank,
43-
const int& nspin) {
39+
const int nw,
40+
const int nb2d,
41+
const int nproc,
42+
const int my_rank,
43+
const int nspin) {
4444
const int kpar = this->get_kpar();
4545
if(kpar <= 0 || kpar > nproc)
4646
{

source/module_hsolver/parallel_k2d.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,20 @@ class Parallel_K2D {
2626
*/
2727
/// this function sets the parallel environment for k-points parallelism
2828
/// including the global and pool 2D parallel distribution
29+
/// nspin doesn't affect anything here, in fact it can be deleted
2930
void set_para_env(int nks,
30-
const int& nw,
31-
const int& nb2d,
32-
const int& nproc,
33-
const int& my_rank,
34-
const int& nspin);
31+
const int nw,
32+
const int nb2d,
33+
const int nproc,
34+
const int my_rank,
35+
const int nspin = 1);
3536

3637
void set_para_env_cusolver(int nks,
37-
const int& nw,
38-
const int& nb2d,
39-
const int& nproc,
40-
const int& my_rank,
41-
const int& nspin);
38+
const int nw,
39+
const int nb2d,
40+
const int nproc,
41+
const int my_rank,
42+
const int nspin = 1);
4243

4344
/// this function distributes the Hk and Sk matrices to hk_pool and sk_pool
4445
void distribute_hsk(hamilt::Hamilt<TK>* pHamilt,

0 commit comments

Comments
 (0)