Skip to content

Commit 9900bb7

Browse files
committed
fix bug
1 parent b15cd5c commit 9900bb7

File tree

8 files changed

+42
-14
lines changed

8 files changed

+42
-14
lines changed

source/module_elecstate/cal_dm.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
2727
//dm.fix_k(ik);
2828
dm[ik].create(ParaV->ncol, ParaV->nrow);
2929
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
30-
// psi::Psi<double> wg_wfc(wfc, 1);
31-
psi::Psi<double> wg_wfc(1, nbands_local, nbasis_local);
30+
psi::Psi<double> wg_wfc(wfc, 1, nbands_local);
3231

3332
int ib_global = 0;
3433
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)

source/module_elecstate/module_dm/cal_dm_psi.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
3232
// dm.fix_k(ik);
3333
// dm[ik].create(ParaV->ncol, ParaV->nrow);
3434
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
35-
// psi::Psi<double> wg_wfc(wfc, 1, );
36-
psi::Psi<double> wg_wfc(1, nbands_local, nbasis_local);
35+
36+
psi::Psi<double> wg_wfc(wfc, 1, nbands_local);
3737

3838
int ib_global = 0;
3939
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)

source/module_hamilt_general/operator.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,7 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
156156
else if(hpsi_pointer == psi_pointer)
157157
{
158158
this->in_place = true;
159-
// this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
160-
this->hpsi = new psi::Psi<T, Device>(1, nbands_range, std::get<0>(info)->get_nbasis());
159+
this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
161160
}
162161
else
163162
{

source/module_io/get_pchg_lcao.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,8 @@ void IState_Charge::idmatrix(const int& ib,
478478

479479
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
480480
this->psi_gamma->fix_k(is);
481-
// psi::Psi<double> wg_wfc(*this->psi_gamma, 1);
482-
psi::Psi<double> wg_wfc(1, this->psi_gamma->get_nbands(), this->psi_gamma->get_nbasis());
481+
482+
psi::Psi<double> wg_wfc(*this->psi_gamma, 1, this->psi_gamma->get_nbands());
483483

484484
for (int ir = 0; ir < wg_wfc.get_nbands(); ++ir)
485485
{

source/module_io/write_dos_lcao.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,7 @@ void ModuleIO::write_dos_lcao(const UnitCell& ucell,
461461
}
462462

463463
psi->fix_k(ik);
464-
// psi::Psi<std::complex<double>> Dwfc(psi[0], 1);
465-
psi::Psi<std::complex<double>> Dwfc(1, psi->get_nbands(), psi->get_nbasis());
464+
psi::Psi<std::complex<double>> Dwfc(*psi, 1, psi->get_nbands());
466465

467466
std::complex<double>* p_dwfc = Dwfc.get_pointer();
468467
for (int index = 0; index < Dwfc.size(); ++index)

source/module_io/write_proj_band_lcao.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,7 @@ void ModuleIO::write_proj_band_lcao(
221221

222222
// calculate Mulk
223223
psi->fix_k(ik);
224-
// psi::Psi<std::complex<double>> Dwfc(psi[0], 1);
225-
psi::Psi<std::complex<double>> Dwfc(1, psi->get_nbands(), psi->get_nbasis());
224+
psi::Psi<std::complex<double>> Dwfc(psi[0], 1, psi->get_nbands());
226225

227226
std::complex<double>* p_dwfc = Dwfc.get_pointer();
228227
for (int index = 0; index < Dwfc.size(); ++index)

source/module_psi/psi.cpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,34 @@ Psi<T, Device>::Psi(T* psi_pointer, const int nk_in, const int nbd_in, const int
107107
base_device::information::print_device_info<Device>(this->ctx, GlobalV::ofs_device);
108108
}
109109

110+
111+
template <typename T, typename Device>
112+
Psi<T, Device>::Psi(const Psi& psi_in, const int nk_in, int nband_in)
113+
{
114+
assert(nk_in <= psi_in.get_nk());
115+
if (nband_in == 0)
116+
{
117+
nband_in = psi_in.get_nbands();
118+
}
119+
this->k_first = psi_in.get_k_first();
120+
this->device = psi_in.device;
121+
this->resize(nk_in, nband_in, psi_in.get_nbasis());
122+
this->ngk = psi_in.ngk;
123+
this->npol = psi_in.npol;
124+
if (nband_in <= psi_in.get_nbands())
125+
{
126+
// copy from Psi from psi_in(current_k, 0, 0),
127+
// if size of k is 1, current_k in new Psi is psi_in.current_k
128+
if (nk_in == 1)
129+
{
130+
// current_k for this Psi only keep the spin index same as the copied Psi
131+
this->current_k = psi_in.get_current_k();
132+
}
133+
synchronize_memory_op()(this->ctx, psi_in.get_device(), this->psi, psi_in.get_pointer(), this->size());
134+
}
135+
}
136+
137+
110138
template <typename T, typename Device>
111139
Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in)
112140
{
@@ -208,11 +236,11 @@ template <typename T, typename Device>
208236
void Psi<T, Device>::resize(const int nks_in, const int nbands_in, const int nbasis_in)
209237
{
210238
assert(nks_in > 0 && nbands_in >= 0 && nbasis_in > 0);
211-
239+
212240
// This function will delete the psi array first(if psi exist), then malloc a new memory for it.
213241
resize_memory_op()(this->ctx, this->psi, nks_in * static_cast<std::size_t>(nbands_in) * nbasis_in, "no_record");
214242

215-
this->zero_out();
243+
// this->zero_out();
216244

217245
this->nk = nks_in;
218246
this->nbands = nbands_in;

source/module_psi/psi.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ class Psi
4242
// Constructor 3: specify nk, nbands, nbasis, ngk, and do not need to call resize() later
4343
Psi(const int nk_in, const int nbd_in, const int nbs_in, const int* ngk_in = nullptr, const bool k_first_in = true);
4444

45+
// Constructor 4: copy a new Psi which have several k-points and several bands from inputted psi_in
46+
Psi(const Psi& psi_in, const int nk_in, int nband_in = 0);
47+
48+
4549
// Constructor 5: a wrapper of a data pointer, used for Operator::hPsi()
4650
// in this case, fix_k can not be used
4751
Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in = 0);

0 commit comments

Comments
 (0)