diff --git a/source/module_hamilt_general/operator.cpp b/source/module_hamilt_general/operator.cpp index a7cf8a6437..76efc16807 100644 --- a/source/module_hamilt_general/operator.cpp +++ b/source/module_hamilt_general/operator.cpp @@ -11,7 +11,8 @@ Operator::Operator(){} template Operator::~Operator() { - if(this->hpsi != nullptr) delete this->hpsi; + if(this->hpsi != nullptr) { delete this->hpsi; +} Operator* last = this->next_op; Operator* last_sub = this->next_sub_op; while(last != nullptr || last_sub != nullptr) @@ -69,7 +70,7 @@ typename Operator::hpsi_info Operator::hPsi(hpsi_info& inp op->act(psi_wrapper, *this->hpsi, nbands); break; default: - op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node); + op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ik_nbas(op->ik), is_first_node); break; } }; @@ -100,9 +101,11 @@ void Operator::init(const int ik_in) template void Operator::add(Operator* next) { - if(next==nullptr) return; + if(next==nullptr) { return; +} next->is_first_node = false; - if(next->next_op != nullptr) this->add(next->next_op); + if(next->next_op != nullptr) { this->add(next->next_op); +} Operator* last = this; //loop to end of the chain while(last->next_op != nullptr) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp index 858e6b3fd5..e8e5876913 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.cpp @@ -47,7 +47,7 @@ void Velocity::act ) const { ModuleBase::timer::tick("Operator", "Velocity"); - const int npw = psi_in->get_ngk(this->ik); + const int npw = psi_in->get_ik_nbas(this->ik); const int max_npw = psi_in->get_nbasis() / psi_in->npol; const int npol = psi_in->npol; const std::complex* tmpsi_in = psi0; diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp index 5ec443ab4e..0a4c8e7b97 100644 --- a/source/module_hsolver/diago_iter_assist.cpp +++ b/source/module_hsolver/diago_iter_assist.cpp @@ -199,7 +199,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* if (base_device::get_device_type(ctx) == base_device::GpuDevice) { - psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, 1, psi_nc, &evc.get_ik_nbas(0)); T* ppsi = psi_temp.get_pointer(); // hpsi and spsi share the temp space T* temp = nullptr; @@ -212,7 +212,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* { // psi_temp is one band psi, psi is all bands psi, the range always is 1 for the only band in psi_temp syncmem_complex_op()(ctx, ctx, ppsi, psi + i * psi_nc, psi_nc); - psi::Range band_by_band_range(1, 0, 0, 0); + psi::Range band_by_band_range(true, 0, 0, 0); hpsi_info hpsi_in(&psi_temp, band_by_band_range, hpsi); // H|Psi> to get hpsi for target band @@ -246,7 +246,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* } else if (base_device::get_device_type(ctx) == base_device::CpuDevice) { - psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ngk(0)); + psi::Psi psi_temp(1, nstart, psi_nc, &evc.get_ik_nbas(0)); T* ppsi = psi_temp.get_pointer(); syncmem_complex_op()(ctx, ctx, ppsi, psi, psi_temp.size()); // hpsi and spsi share the temp space @@ -256,7 +256,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* T* hpsi = temp; // do hPsi for all bands - psi::Range all_bands_range(1, 0, 0, nstart - 1); + psi::Range all_bands_range(true, 0, 0, nstart - 1); hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); pHamilt->ops->hPsi(hpsi_in); @@ -586,8 +586,9 @@ bool DiagoIterAssist::test_exit_cond(const int& ntry, const int& notc //================================================================ bool scf = true; - if (PARAM.inp.calculation == "nscf") + if (PARAM.inp.calculation == "nscf") { scf = false; +} // If ntry <=5, try to do it better, if ntry > 5, exit. const bool f1 = (ntry <= 5); diff --git a/source/module_io/unk_overlap_pw.cpp b/source/module_io/unk_overlap_pw.cpp index d0d1d7c706..e05f2fe56c 100644 --- a/source/module_io/unk_overlap_pw.cpp +++ b/source/module_io/unk_overlap_pw.cpp @@ -29,12 +29,12 @@ std::complex unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw, ModuleBase::GlobalFunc::ZEROS(unk_R,number_pw); - for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) + for (int igl = 0; igl < evc->get_ik_nbas(ik_L); igl++) { unk_L[wfcpw->getigl2ig(ik_L,igl)] = evc[0](ik_L, iband_L, igl); } - for (int igl = 0; igl < evc->get_ngk(ik_R); igl++) + for (int igl = 0; igl < evc->get_ik_nbas(ik_R); igl++) { unk_R[wfcpw->getigl2ig(ik_R,igl)] = evc[0](ik_R, iband_R, igl); } @@ -103,7 +103,7 @@ std::complex unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw, // (3) calculate the overlap in ik_L and ik_R wfcpw->real2recip(psi_r, psi_r, ik_R); - for (int ig = 0; ig < evc->get_ngk(ik_R); ig++) + for (int ig = 0; ig < evc->get_ik_nbas(ik_R); ig++) { result = result + conj(psi_r[ig]) * evc[0](ik_R, iband_R, ig); } @@ -143,12 +143,12 @@ std::complex unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf for(int i = 0; i < PARAM.globalv.npol; i++) { - for (int igl = 0; igl < evc->get_ngk(ik_L); igl++) + for (int igl = 0; igl < evc->get_ik_nbas(ik_L); igl++) { unk_L[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_L, iband_L, igl + i * npwx); } - for (int igl = 0; igl < evc->get_ngk(ik_R); igl++) + for (int igl = 0; igl < evc->get_ik_nbas(ik_R); igl++) { unk_R[wfcpw->getigl2ig(ik_L, igl) + i * number_pw] = evc[0](ik_R, iband_R, igl + i * npwx); } @@ -223,7 +223,7 @@ std::complex unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho for (int i = 0; i < PARAM.globalv.npol; i++) { - for(int ig = 0; ig < evc->get_ngk(ik_R); ig++) + for(int ig = 0; ig < evc->get_ik_nbas(ik_R); ig++) { if( i == 0 ) { result = result + conj( psi_up[ig] ) * evc[0](ik_R, iband_R, ig); } diff --git a/source/module_io/write_vxc_lip.hpp b/source/module_io/write_vxc_lip.hpp index 205fdbb057..8d1ff6a6df 100644 --- a/source/module_io/write_vxc_lip.hpp +++ b/source/module_io/write_vxc_lip.hpp @@ -148,7 +148,7 @@ namespace ModuleIO hpsi_localxc.fix_k(ik); #ifdef __DEBUG assert(hpsi_localxc.get_current_nbas() == psi_pw.get_current_nbas()); - assert(hpsi_localxc.get_current_nbas() == hpsi_localxc.get_ngk(ik)); + assert(hpsi_localxc.get_current_nbas() == hpsi_localxc.get_ik_nbas(ik)); #endif /// wrap psi and act band-by-band (the same result as act all bands at once) // for (int ib = 0;ib < psi_pw.get_nbands();++ib) @@ -156,9 +156,9 @@ namespace ModuleIO // std::cout<<"ib="< psi_single_band(&psi_pw(ik, ib, 0), 1, 1, psi_pw.get_current_nbas()); // psi::Psi hpsi_single_band(&hpsi_localxc(ik, ib, 0), 1, 1, hpsi_localxc.get_current_nbas()); - // vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ngk(ik)); + // vxcs_op_pw->act(1, psi_pw.get_current_nbas(), psi_pw.npol, psi_single_band.get_pointer(), hpsi_single_band.get_pointer(), psi_pw.get_ik_nbas(ik)); // } - vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ngk(ik)); + vxcs_op_pw->act(psi_pw.get_nbands(), psi_pw.get_nbasis(), psi_pw.npol, &psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_ik_nbas(ik)); delete vxcs_op_pw; std::vector vxc_local_k_mo = psi_Hpsi(&psi_pw(ik, 0, 0), &hpsi_localxc(ik, 0, 0), psi_pw.get_nbasis(), psi_pw.get_nbands()); Parallel_Reduce::reduce_pool(vxc_local_k_mo.data(), nbands * nbands); diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index f129d3e422..4d3efe2cfb 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -76,7 +76,6 @@ template Psi::Psi(T* psi_pointer, const this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - this->current_nbasis = nbs_in; this->psi_current = this->psi = psi_pointer; this->allocate_inside = false; // Currently only GPU's implementation is supported for device recording! @@ -148,7 +147,6 @@ template Psi::Psi(const Psi& psi_in) psi_in.get_pointer() - psi_in.get_psi_bias(), psi_in.size()); this->psi_bias = psi_in.get_psi_bias(); - this->current_nbasis = psi_in.get_current_nbas(); this->psi_current = this->psi + psi_in.get_psi_bias(); } @@ -200,7 +198,6 @@ Psi::Psi(const Psi& psi_in) psi_in.size()); } this->psi_bias = psi_in.get_psi_bias(); - this->current_nbasis = psi_in.get_current_nbas(); this->psi_current = this->psi + psi_in.get_psi_bias(); } @@ -213,7 +210,6 @@ void Psi::resize(const int nks_in, const int nbands_in, const int nba this->nk = nks_in; this->nbands = nbands_in; this->nbasis = nbasis_in; - this->current_nbasis = nbasis_in; this->psi_current = this->psi; // GlobalV::ofs_device << "allocated xxx MB memory for psi" << std::endl; } @@ -276,27 +272,26 @@ template std::size_t Psi::size() const template void Psi::fix_k(const int ik) const { - assert(ik >= 0); - this->current_k = ik; - if (this->ngk != nullptr && this->npol != 2) - this->current_nbasis = this->ngk[ik]; - else - this->current_nbasis = this->nbasis; + assert(ik >= 0 && ik < this->nk); - if (this->k_first)this->current_b = 0; - int base = this->current_b * this->nk * this->nbasis; - if (ik >= this->nk) + if (this->k_first == true) { - // mem_saver: fix to base - this->psi_bias = base; - this->psi_current = const_cast(&(this->psi[base])); + this->current_k = ik; + this->current_b = 0; + + this->psi_bias = this->current_k * this->nbands * this->nbasis; + this->psi_current = this->psi + this->psi_bias; } else { - this->psi_bias = k_first ? ik * this->nbands * this->nbasis : base + ik * this->nbasis; - this->psi_current = const_cast(&(this->psi[psi_bias])); + this->current_k = ik; + // this->current_b remains unchanged + + this->psi_bias = this->current_b * this->nk * this->nbasis + this->current_k * this->nbasis; + this->psi_current = this->psi + this->psi_bias; } } + template void Psi::fix_b(const int ib) const { assert(ib >= 0); @@ -366,15 +361,28 @@ template int Psi::get_current_b() const return this->current_b; } -template int Psi::get_current_nbas() const +template const int& Psi::get_current_nbas() const { - return this->current_nbasis; + if (this->ngk != nullptr) + { + return this->ngk[this->current_k]; + } + else + { + return this->nbasis; + } } -template const int& Psi::get_ngk(const int ik_in) const +template const int& Psi::get_ik_nbas(const int ik_in) const { - if (!this->ngk) return this->nbasis; - return this->ngk[ik_in]; + if (this->ngk != nullptr) + { + return this->ngk[ik_in]; + } + else + { + return this->nbasis; + } } template void Psi::zero_out() diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 283c641204..b7f39796ba 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -86,21 +86,27 @@ class Psi /// if k_first=true, ikb=ik, ikb2=iband /// if k_first=false, ikb=iband, ikb2=ik T& operator()(const int ikb1, const int ikb2, const int ibasis) const; + /// use operator "(ikb2, ibasis)" to reach target element for current k /// if k_first=true, ikb2=iband /// if k_first=false, ikb2=ik T& operator()(const int ikb2, const int ibasis) const; + // use operator "(ibasis)" to reach target element for current k and current band T& operator()(const int ibasis) const; // return current k-point index int get_current_k() const; + // return current band index int get_current_b() const; - // return current ngk for PW base - int get_current_nbas() const; - const int& get_ngk(const int ik_in) const; + // return the nbasis of current k + const int& get_current_nbas() const; + + // return the nbasis of ik_in + const int& get_ik_nbas(const int ik_in) const; + // return ngk array of psi const int* get_ngk_pointer() const; // return k_first diff --git a/source/module_psi/test/psi_test.cpp b/source/module_psi/test/psi_test.cpp index d031897e7e..b2b8787d55 100644 --- a/source/module_psi/test/psi_test.cpp +++ b/source/module_psi/test/psi_test.cpp @@ -63,23 +63,23 @@ TEST_F(TestPsi, get_val) EXPECT_EQ(psi_object14->get_psi_bias(), 0); } -TEST_F(TestPsi, get_ngk) +TEST_F(TestPsi, get_ik_nbas) { psi::Psi>* psi_object21 = new psi::Psi>(&ngk[0]); psi::Psi* psi_object22 = new psi::Psi(&ngk[0]); psi::Psi>* psi_object23 = new psi::Psi>(&ngk[0]); psi::Psi* psi_object24 = new psi::Psi(&ngk[0]); - EXPECT_EQ(psi_object21->get_ngk(2), ngk[2]); + EXPECT_EQ(psi_object21->get_ik_nbas(2), ngk[2]); EXPECT_EQ(psi_object21->get_ngk_pointer()[0], ngk[0]); - EXPECT_EQ(psi_object22->get_ngk(2), ngk[2]); + EXPECT_EQ(psi_object22->get_ik_nbas(2), ngk[2]); EXPECT_EQ(psi_object22->get_ngk_pointer()[0], ngk[0]); - EXPECT_EQ(psi_object23->get_ngk(2), ngk[2]); + EXPECT_EQ(psi_object23->get_ik_nbas(2), ngk[2]); EXPECT_EQ(psi_object23->get_ngk_pointer()[0], ngk[0]); - EXPECT_EQ(psi_object24->get_ngk(2), ngk[2]); + EXPECT_EQ(psi_object24->get_ik_nbas(2), ngk[2]); EXPECT_EQ(psi_object24->get_ngk_pointer()[0], ngk[0]); }