Skip to content

Commit 51b4c88

Browse files
authored
Refactor: refactor npol, constructors and int *ngk of Psi class (#5863)
* change npol to private * fix cuda build bug * fix cuda build bug * fix build bug in cuda * remove npol value in psi * fix bug * fix bugs * update constructers * fix bug * fix test bug * remove Constructor 1-1 * fix bug * update psi * remove useless code
1 parent e927eab commit 51b4c88

File tree

32 files changed

+170
-182
lines changed

32 files changed

+170
-182
lines changed

source/module_elecstate/cal_dm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
8282
//dm.fix_k(ik);
8383
dm[ik].create(ParaV->ncol, ParaV->nrow);
8484
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
85-
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr);
85+
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), wfc.get_nbasis(), true);
8686
const std::complex<double>* pwfc = wfc.get_pointer();
8787
std::complex<double>* pwg_wfc = wg_wfc.get_pointer();
8888
#ifdef _OPENMP

source/module_elecstate/elecstate_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
271271
{
272272
const T one{1, 0};
273273
const T zero{0, 0};
274-
const int npol = psi.npol;
274+
const int npol = psi.get_npol();
275275
const int npwx = psi.get_nbasis() / npol;
276276
const int nbands = psi.get_nbands() * npol;
277277
const int nkb = this->ppcell->nkb;

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ void ESolver_KS_LCAO_TDDFT::update_pot(UnitCell& ucell, const int istep, const i
196196
if (this->psi_laststep == nullptr)
197197
{
198198
#ifdef __MPI
199-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_nbands, nrow, nullptr);
199+
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_nbands, nrow, kv.ngk, true);
200200
#else
201-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), nbands, nlocal, nullptr);
201+
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), nbands, nlocal, kv.ngk, true);
202202
#endif
203203
}
204204

source/module_esolver/esolver_ks_lcaopw.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ namespace ModuleESolver
9393
ESolver_KS_PW<T>::before_all_runners(ucell, inp);
9494
delete this->psi_local;
9595
this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
96-
this->p_psi_init->psi_initer->nbands_start(),
97-
this->psi->get_nbasis(),
98-
this->psi->get_ngk_pointer());
96+
this->p_psi_init->psi_initer->nbands_start(),
97+
this->psi->get_nbasis(),
98+
this->kv.ngk,
99+
true);
99100
#ifdef __EXX
100101
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
101102
|| PARAM.inp.calculation == "cell-relax"

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
212212
this->kv,
213213
this->ppcell,
214214
*this->pw_wfc);
215-
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
215+
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max);
216216
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
217217

218218
this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,20 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input
7878
// 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
7979
size_t size = stowf.chi0->size();
8080
this->stowf.shchi
81-
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data());
81+
= new psi::Psi<T, Device>(this->kv.get_nks(),
82+
this->stowf.nchip_max,
83+
this->pw_wfc->npwk_max,
84+
this->kv.ngk,
85+
true);
8286
ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T));
8387

8488
if (PARAM.inp.nbands > 0)
8589
{
8690
this->stowf.chiortho
87-
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data());
91+
= new psi::Psi<T, Device>(this->kv.get_nks(),
92+
this->stowf.nchip_max,
93+
this->pw_wfc->npwk_max,
94+
this->kv.ngk, true);
8895
ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T));
8996
}
9097

source/module_esolver/lcao_before_scf.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
159159
ncol = PARAM.inp.nbands;
160160
#endif
161161
}
162-
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, nullptr);
162+
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, this->kv.ngk, true);
163163
}
164164

165165
// init wfc from file

source/module_esolver/lcao_others.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
165165
ncol = PARAM.inp.nbands;
166166
#endif
167167
}
168-
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, nullptr);
168+
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, this->kv.ngk, true);
169169
}
170170

171171
// init wfc from file

source/module_hamilt_general/operator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
6363
delete this->hpsi;
6464
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
6565
1,
66-
nbands / psi_input->npol,
66+
nbands / psi_input->get_npol(),
6767
psi_input->get_nbasis(),
6868
psi_input->get_nbasis(),
6969
true);
@@ -86,7 +86,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
8686
default:
8787
op->act(nbands,
8888
psi_input->get_nbasis(),
89-
psi_input->npol,
89+
psi_input->get_npol(),
9090
tmpsi_in,
9191
this->hpsi->get_pointer(),
9292
psi_input->get_current_nbas(),
@@ -105,7 +105,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
105105
}
106106
ModuleBase::timer::tick("Operator", "hPsi");
107107

108-
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
108+
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->get_npol()), hpsi_pointer);
109109
}
110110

111111
template <typename T, typename Device>

source/module_hamilt_lcao/module_deltaspin/cal_mw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mi_pw()
6666
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi_t = static_cast<psi::Psi<std::complex<double>, base_device::DEVICE_CPU>*>(this->psi);
6767
const int nbands = psi_t->get_nbands();
6868
const int nks = psi_t->get_nk();
69-
const int npol = psi_t->npol;
69+
const int npol = psi_t->get_npol();
7070
for(int ik = 0; ik < nks; ik++)
7171
{
7272
psi_t->fix_k(ik);
@@ -112,7 +112,7 @@ void spinconstrain::SpinConstrain<std::complex<double>>::cal_mi_pw()
112112
psi::Psi<std::complex<double>, base_device::DEVICE_GPU>* psi_t = static_cast<psi::Psi<std::complex<double>, base_device::DEVICE_GPU>*>(this->psi);
113113
const int nbands = psi_t->get_nbands();
114114
const int nks = psi_t->get_nk();
115-
const int npol = psi_t->npol;
115+
const int npol = psi_t->get_npol();
116116
for(int ik = 0; ik < nks; ik++)
117117
{
118118
psi_t->fix_k(ik);

0 commit comments

Comments
 (0)