Skip to content

Commit 1489157

Browse files
committed
update constructers
1 parent f1e2cca commit 1489157

21 files changed

+64
-48
lines changed

source/module_elecstate/cal_dm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
8282
//dm.fix_k(ik);
8383
dm[ik].create(ParaV->ncol, ParaV->nrow);
8484
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
85-
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr);
85+
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), wfc.get_nbasis(), true);
8686
const std::complex<double>* pwfc = wfc.get_pointer();
8787
std::complex<double>* pwg_wfc = wg_wfc.get_pointer();
8888
#ifdef _OPENMP

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ void ESolver_KS_LCAO_TDDFT::update_pot(UnitCell& ucell, const int istep, const i
196196
if (this->psi_laststep == nullptr)
197197
{
198198
#ifdef __MPI
199-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_nbands, nrow, nullptr);
199+
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_nbands, nrow, kv.ngk, true);
200200
#else
201-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), nbands, nlocal, nullptr);
201+
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), nbands, nlocal, kv.ngk, true);
202202
#endif
203203
}
204204

source/module_esolver/esolver_ks_lcaopw.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ namespace ModuleESolver
9393
ESolver_KS_PW<T>::before_all_runners(ucell, inp);
9494
delete this->psi_local;
9595
this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
96-
this->p_psi_init->psi_initer->nbands_start(),
97-
this->psi->get_nbasis(),
98-
this->psi->get_ngk_pointer());
96+
this->p_psi_init->psi_initer->nbands_start(),
97+
this->psi->get_nbasis(),
98+
this->kv.ngk,
99+
true);
99100
#ifdef __EXX
100101
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
101102
|| PARAM.inp.calculation == "cell-relax"

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
212212
this->kv,
213213
this->ppcell,
214214
*this->pw_wfc);
215-
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
215+
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.inp.nbands, this->pw_wfc->npwk_max);
216216
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
217217

218218
this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,20 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input
7878
// 4) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
7979
size_t size = stowf.chi0->size();
8080
this->stowf.shchi
81-
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data());
81+
= new psi::Psi<T, Device>(this->kv.get_nks(),
82+
this->stowf.nchip_max,
83+
this->pw_wfc->npwk_max,
84+
this->kv.ngk,
85+
true);
8286
ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T));
8387

8488
if (PARAM.inp.nbands > 0)
8589
{
8690
this->stowf.chiortho
87-
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->pw_wfc->npwk_max, this->kv.ngk.data());
91+
= new psi::Psi<T, Device>(this->kv.get_nks(),
92+
this->stowf.nchip_max,
93+
this->pw_wfc->npwk_max,
94+
this->kv.ngk, true);
8895
ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T));
8996
}
9097

source/module_esolver/lcao_before_scf.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
159159
ncol = PARAM.inp.nbands;
160160
#endif
161161
}
162-
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, nullptr);
162+
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, this->kv.ngk, true);
163163
}
164164

165165
// init wfc from file

source/module_esolver/lcao_others.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
165165
ncol = PARAM.inp.nbands;
166166
#endif
167167
}
168-
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, nullptr);
168+
this->psi = new psi::Psi<TK>(nsk, ncol, this->pv.nrow, this->kv.ngk, true);
169169
}
170170

171171
// init wfc from file

source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ template <typename T, typename Device>
3535
void Stochastic_WF<T, Device>::init(K_Vectors* p_kv, const int npwx_in)
3636
{
3737
this->nks = p_kv->get_nks();
38-
this->ngk = p_kv->ngk.data();
38+
this->ngk = p_kv->ngk;
3939
this->npwx = npwx_in;
4040
nchip = new int[nks];
4141

@@ -111,7 +111,7 @@ void Stochastic_WF<T, Device>::allocate_chi0()
111111

112112
this->nchip_max = tmpnchip;
113113
size_t size = this->nchip_max * npwx * nks;
114-
this->chi0_cpu = new psi::Psi<T>(nks, this->nchip_max, npwx, this->ngk);
114+
this->chi0_cpu = new psi::Psi<T>(nks, this->nchip_max, npwx, this->ngk, true);
115115
ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(T));
116116

117117
for (int ik = 0; ik < nks; ++ik)
@@ -123,7 +123,7 @@ void Stochastic_WF<T, Device>::allocate_chi0()
123123
Device* ctx = {};
124124
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
125125
{
126-
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk);
126+
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk, true);
127127
}
128128
else
129129
{
@@ -207,7 +207,7 @@ void Stochastic_WF<T, Device>::init_com_orbitals()
207207
delete[] npwip;
208208
}
209209
size_t size = this->nchip_max * npwx * nks;
210-
this->chi0_cpu = new psi::Psi<std::complex<double>>(nks, this->nchip_max, npwx, this->ngk);
210+
this->chi0_cpu = new psi::Psi<std::complex<double>>(nks, this->nchip_max, npwx, this->ngk, true);
211211
this->chi0_cpu->zero_out();
212212
ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex<double>));
213213
for (int ik = 0; ik < nks; ++ik)
@@ -252,7 +252,7 @@ void Stochastic_WF<T, Device>::init_com_orbitals()
252252
Device* ctx = {};
253253
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
254254
{
255-
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk);
255+
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk, true);
256256
}
257257
else
258258
{
@@ -266,7 +266,7 @@ void Stochastic_WF<T, Device>::init_com_orbitals()
266266
const int npwx = this->npwx;
267267
const int nks = this->nks;
268268
size_t size = this->nchip_max * npwx * nks;
269-
this->chi0_cpu = new psi::Psi<std::complex<double>>(nks, npwx, npwx, this->ngk);
269+
this->chi0_cpu = new psi::Psi<std::complex<double>>(nks, npwx, npwx, this->ngk, true);
270270
this->chi0_cpu->zero_out();
271271
ModuleBase::Memory::record("SDFT::chi0_cpu", size * sizeof(std::complex<double>));
272272
for (int ik = 0; ik < nks; ++ik)
@@ -284,7 +284,7 @@ void Stochastic_WF<T, Device>::init_com_orbitals()
284284
Device* ctx = {};
285285
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
286286
{
287-
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk);
287+
this->chi0 = new psi::Psi<T, Device>(nks, this->nchip_max, npwx, this->ngk, ture);
288288
}
289289
else
290290
{

source/module_hamilt_pw/hamilt_stodft/sto_wf.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ class Stochastic_WF
3030
int* nchip = nullptr; ///< The number of stochatic orbitals in current process of each k point.
3131
int nchip_max = 0; ///< Max number of stochastic orbitals among all k points.
3232
int nks = 0; ///< number of k-points
33-
int* ngk = nullptr; ///< ngk in klist
3433
int npwx = 0; ///< max ngk[ik] in all processors
3534
int nbands_diag = 0; ///< number of bands obtained from diagonalization
3635
int nbands_total = 0; ///< number of bands in total, nbands_total=nchi+nbands_diag;
36+
std::vector<int> ngk; ///< ngk in klist
3737
public:
3838
// Tn(H)|chi>
3939
psi::Psi<T, Device>* chiallorder = nullptr;

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ void HSolverLCAO<T, Device>::parakSolve(hamilt::Hamilt<T>* pHamilt,
219219
k2d.distribute_hsk(pHamilt, ik_kpar, nrow);
220220
/// global index of k point
221221
int ik_global = ik + k2d.get_pKpoints()->startk_pool[k2d.get_my_pool()];
222-
auto psi_pool = psi::Psi<T>(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, nullptr);
222+
auto psi_pool = psi::Psi<T>(1, ncol_bands_pool, k2d.get_p2D_pool()->nrow, k2d.get_p2D_pool()->nrow, true);
223223
ModuleBase::Memory::record("HSolverLCAO::psi_pool", nrow * ncol_bands_pool * sizeof(T));
224224
if (ik_global < psi.get_nk() && ik < k2d.get_pKpoints()->nks_pool[k2d.get_my_pool()])
225225
{

0 commit comments

Comments
 (0)