Skip to content

Commit 48065a3

Browse files
authored
Refactor:Remove GloblaC::ucell in module_hsolver (#5657)
* change ucell in the hsolver_pw * change test hsolver_pw * Revert "change ucell in the hsolver_pw" This reverts commit 51ba921. * Revert "change test hsolver_pw" This reverts commit a815fdc. * use parameter trans instead of ucell
1 parent a43cbfb commit 48065a3

File tree

7 files changed

+50
-29
lines changed

7 files changed

+50
-29
lines changed

source/module_esolver/esolver_ks_lcaopw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ namespace ModuleESolver
147147
}
148148

149149
hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
150-
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge);
150+
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);
151151

152152
// add exx
153153
#ifdef __EXX

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,9 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(UnitCell& ucell,
447447
this->pelec->ekb.c,
448448
GlobalV::RANK_IN_POOL,
449449
GlobalV::NPROC_IN_POOL,
450-
skip_charge);
450+
skip_charge,
451+
ucell.tpiba,
452+
ucell.nat);
451453

452454
Symmetry_rho srho;
453455
for (int is = 0; is < PARAM.inp.nspin; is++)

source/module_hsolver/hsolver_lcaopw.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace hsolver
2424

2525
#ifdef USE_PAW
2626
template <typename T>
27-
void HSolverLIP<T>::paw_func_in_kloop(const int ik)
27+
void HSolverLIP<T>::paw_func_in_kloop(const int ik,const double tpiba)
2828
{
2929
if (PARAM.inp.use_paw)
3030
{
@@ -64,7 +64,7 @@ void HSolverLIP<T>::paw_func_in_kloop(const int ik)
6464
this->wfc_basis->get_ig2iy(ik).data(),
6565
this->wfc_basis->get_ig2iz(ik).data(),
6666
(const double**)kpg,
67-
GlobalC::ucell.tpiba,
67+
tpiba,
6868
(const double**)gcar);
6969

7070
std::vector<double>().swap(kpt);
@@ -83,7 +83,10 @@ void HSolverLIP<T>::paw_func_in_kloop(const int ik)
8383
}
8484

8585
template <typename T>
86-
void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState* pes)
86+
void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi,
87+
elecstate::ElecState* pes,
88+
const double tpiba,
89+
const int nat)
8790
{
8891
if (PARAM.inp.use_paw)
8992
{
@@ -131,7 +134,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
131134
this->wfc_basis->get_ig2iy(ik).data(),
132135
this->wfc_basis->get_ig2iz(ik).data(),
133136
(const double**)kpg,
134-
GlobalC::ucell.tpiba,
137+
tpiba,
135138
(const double**)gcar);
136139

137140
std::vector<double>().swap(kpt);
@@ -164,7 +167,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
164167
{
165168
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);
166169

167-
for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
170+
for (int iat = 0; iat < nat; iat++)
168171
{
169172
GlobalC::paw_cell.set_rhoij(iat,
170173
nrhoijsel[iat],
@@ -176,7 +179,7 @@ void HSolverLIP<T>::paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState*
176179
#else
177180
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);
178181

179-
for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
182+
for (int iat = 0; iat < nat; iat++)
180183
{
181184
GlobalC::paw_cell.set_rhoij(iat,
182185
nrhoijsel[iat],
@@ -201,7 +204,9 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
201204
psi::Psi<T>& psi, // ESolver_KS_PW::kspw_psi
202205
elecstate::ElecState* pes, // ESolver_KS_PW::pes
203206
psi::Psi<T>& transform,
204-
const bool skip_charge)
207+
const bool skip_charge,
208+
const double tpiba,
209+
const int nat)
205210
{
206211
ModuleBase::TITLE("HSolverLIP", "solve");
207212
ModuleBase::timer::tick("HSolverLIP", "solve");
@@ -212,7 +217,7 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
212217
pHamilt->updateHk(ik);
213218

214219
#ifdef USE_PAW
215-
this->paw_func_in_kloop(ik);
220+
this->paw_func_in_kloop(ik,tpiba);
216221
#endif
217222

218223
psi.fix_k(ik);
@@ -282,7 +287,7 @@ void HSolverLIP<T>::solve(hamilt::Hamilt<T>* pHamilt, // ESolver_KS_PW::p_hamilt
282287
reinterpret_cast<elecstate::ElecStatePW<T>*>(pes)->psiToRho(psi);
283288

284289
#ifdef USE_PAW
285-
this->paw_func_after_kloop(psi, pes);
290+
this->paw_func_after_kloop(psi, pes,tpiba,nat);
286291
#endif
287292

288293
ModuleBase::timer::tick("HSolverLIP", "solve");

source/module_hsolver/hsolver_lcaopw.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,21 @@ class HSolverLIP
3131
psi::Psi<T>& psi,
3232
elecstate::ElecState* pes,
3333
psi::Psi<T>& transform,
34-
const bool skip_charge);
34+
const bool skip_charge,
35+
const double tpiba,
36+
const int nat);
3537

3638
private:
3739
ModulePW::PW_Basis_K* wfc_basis;
3840

3941
#ifdef USE_PAW
40-
void paw_func_in_kloop(const int ik);
42+
void paw_func_in_kloop(const int ik,
43+
const double tpiba);
4144

42-
void paw_func_after_kloop(psi::Psi<T>& psi, elecstate::ElecState* pes);
45+
void paw_func_after_kloop(psi::Psi<T>& psi,
46+
elecstate::ElecState* pes,
47+
const double tpiba,
48+
const int nat);
4349
#endif
4450
};
4551

source/module_hsolver/hsolver_pw.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ namespace hsolver
2828

2929
#ifdef USE_PAW
3030
template <typename T, typename Device>
31-
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
31+
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik,
32+
const double tpiba)
3233
{
3334
if (this->use_paw)
3435
{
@@ -68,7 +69,7 @@ void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
6869
this->wfc_basis->get_ig2iy(ik).data(),
6970
this->wfc_basis->get_ig2iz(ik).data(),
7071
(const double**)kpg,
71-
GlobalC::ucell.tpiba,
72+
tpiba,
7273
(const double**)gcar);
7374

7475
std::vector<double>().swap(kpt);
@@ -96,7 +97,10 @@ void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
9697
}
9798

9899
template <typename T, typename Device>
99-
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes)
100+
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi,
101+
elecstate::ElecState* pes,
102+
const double tpiba,
103+
const int nat)
100104
{
101105
if (this->use_paw)
102106
{
@@ -144,7 +148,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
144148
this->wfc_basis->get_ig2iy(ik).data(),
145149
this->wfc_basis->get_ig2iz(ik).data(),
146150
(const double**)kpg,
147-
GlobalC::ucell.tpiba,
151+
tpiba,
148152
(const double**)gcar);
149153

150154
std::vector<double>().swap(kpt);
@@ -177,7 +181,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
177181
{
178182
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);
179183

180-
for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
184+
for (int iat = 0; iat < nat; iat++)
181185
{
182186
GlobalC::paw_cell.set_rhoij(iat,
183187
nrhoijsel[iat],
@@ -189,7 +193,7 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst
189193
#else
190194
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);
191195

192-
for (int iat = 0; iat < GlobalC::ucell.nat; iat++)
196+
for (int iat = 0; iat < nat; iat++)
193197
{
194198
GlobalC::paw_cell.set_rhoij(iat,
195199
nrhoijsel[iat],
@@ -255,7 +259,9 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
255259
double* out_eigenvalues,
256260
const int rank_in_pool_in,
257261
const int nproc_in_pool_in,
258-
const bool skip_charge)
262+
const bool skip_charge,
263+
const double tpiba,
264+
const int nat)
259265
{
260266
ModuleBase::TITLE("HSolverPW", "solve");
261267
ModuleBase::timer::tick("HSolverPW", "solve");
@@ -282,7 +288,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
282288
pHamilt->updateHk(ik);
283289

284290
#ifdef USE_PAW
285-
this->paw_func_in_kloop(ik);
291+
this->paw_func_in_kloop(ik,tpiba);
286292
#endif
287293

288294
/// update psi pointer for each k point
@@ -341,7 +347,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
341347
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);
342348

343349
#ifdef USE_PAW
344-
this->paw_func_after_kloop(psi, pes);
350+
this->paw_func_after_kloop(psi, pes,tpiba,nat);
345351
#endif
346352

347353
ModuleBase::timer::tick("HSolverPW", "solve");

source/module_hsolver/hsolver_pw.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ class HSolverPW
4747
double* out_eigenvalues,
4848
const int rank_in_pool_in,
4949
const int nproc_in_pool_in,
50-
const bool skip_charge);
50+
const bool skip_charge,
51+
const double tpiba,
52+
const int nat);
5153

5254
protected:
5355
// diago caller
@@ -89,11 +91,12 @@ class HSolverPW
8991
std::vector<double> ethr_band;
9092

9193
#ifdef USE_PAW
92-
void paw_func_in_kloop(const int ik);
94+
void paw_func_in_kloop(const int ik,
95+
const double tpiba);
9396

9497
void call_paw_cell_set_currentk(const int ik);
9598

96-
void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes);
99+
void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes,const double tpiba,const int nat);
97100
#endif
98101
};
99102

source/module_hsolver/test/test_hsolver_pw.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {
249249
= hsolver::HSolverLIP<std::complex<float>>(&pwbk);
250250
hsolver::HSolverLIP<std::complex<double>> hs_d_lip
251251
= hsolver::HSolverLIP<std::complex<double>>(&pwbk);
252-
hs_f_lip.solve(&hamilt_test_f, psi_test_cf, &elecstate_test,
253-
transform_test_cf, true);
252+
hs_f_lip.solve(&hamilt_test_f, psi_test_cf, &elecstate_test,transform_test_cf, true,0.0,0);
254253
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<float>>::avg_iter, 0.0);
255254
for (int i = 0; i < psi_test_cf.size(); i++)
256255
{
@@ -261,7 +260,7 @@ TEST_F(TestHSolverPW, SolveLcaoInPW) {
261260

262261
elecstate_test.ekb.c[0] = 1.0;
263262
elecstate_test.ekb.c[1] = 2.0;
264-
hs_d_lip.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, transform_test_cd, true);
263+
hs_d_lip.solve(&hamilt_test_d, psi_test_cd, &elecstate_test, transform_test_cd, true,0.0,0);
265264
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter, 0.0);
266265
for (int i = 0; i < psi_test_cd.size(); i++)
267266
{

0 commit comments

Comments
 (0)