diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 3d6a993bb2..b0e64604b1 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -376,13 +376,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->kspw_psi, this->p_hamilt, this->ppcell, - GlobalV::ofs_running, - this->already_initpsi); - - if (this->already_initpsi == false) - { - this->already_initpsi = true; - } + GlobalV::ofs_running); } } diff --git a/source/module_esolver/esolver_ks_pw.h b/source/module_esolver/esolver_ks_pw.h index da5a7dcc22..860f29f684 100644 --- a/source/module_esolver/esolver_ks_pw.h +++ b/source/module_esolver/esolver_ks_pw.h @@ -63,8 +63,6 @@ class ESolver_KS_PW : public ESolver_KS psi::Psi, Device>* __kspw_psi = nullptr; - bool already_initpsi = false; - using castmem_2d_d2h_op = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 2eec4fa6ca..34a17de4f7 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -28,30 +28,6 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt* pHamilt, const int nbands = psi.get_nbands(); const int nks = psi.get_nk(); - //--------------------------------------------------------------------------------------------------------------- - //---------------------------------for psi init guess!!!!-------------------------------------------------------- - //--------------------------------------------------------------------------------------------------------------- - // if (!PARAM.inp.psi_initializer && !this->initialed_psi && this->basis_type == "pw") - // { - // for (int ik = 0; ik < nks; ++ik) - // { - // /// update H(k) for each k point - // pHamilt->updateHk(ik); - - // if (nbands > 0 && GlobalV::MY_STOGROUP == 0) - // { - // /// update psi pointer for each k point - // psi.fix_k(ik); - - // /// for psi init guess!!!! - // hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt); - // } - // } - // } - //--------------------------------------------------------------------------------------------------------------- - //--------------------------------------------------------------------------------------------------------------- - //--------------------------------------------------------------------------------------------------------------- - // prepare for the precondition of diagonalization std::vector precondition(psi.get_nbasis(), 0.0); diff --git a/source/module_psi/psi_init.cpp b/source/module_psi/psi_init.cpp index fdd08958ba..51debdc47d 100644 --- a/source/module_psi/psi_init.cpp +++ b/source/module_psi/psi_init.cpp @@ -152,8 +152,7 @@ void PSIInit::initialize_psi(Psi>* psi, psi::Psi* kspw_psi, hamilt::Hamilt* p_hamilt, const pseudopot_cell_vnl& nlpp, - std::ofstream& ofs_running, - const bool is_already_initpsi) + std::ofstream& ofs_running) { ModuleBase::timer::tick("PSIInit", "initialize_psi"); @@ -255,8 +254,7 @@ void PSIInit::initialize_psi(Psi>* psi, } else { - //! note: is_already_initpsi will be false in init_after_vc when vc changes. - if (PARAM.inp.basis_type == "pw" && is_already_initpsi == false) + if (PARAM.inp.basis_type == "pw") { for (int ik = 0; ik < this->pw_wfc->nks; ++ik) { diff --git a/source/module_psi/psi_init.h b/source/module_psi/psi_init.h index df0ec83446..5dc0d95ac6 100644 --- a/source/module_psi/psi_init.h +++ b/source/module_psi/psi_init.h @@ -48,14 +48,12 @@ class PSIInit * @param psi store the wavefunction * @param p_hamilt Hamiltonian operator * @param ofs_running output stream for running information - * @param is_already_initpsi whether psi has been initialized */ void initialize_psi(Psi>* psi, psi::Psi* kspw_psi, hamilt::Hamilt* p_hamilt, const pseudopot_cell_vnl& nlpp, - std::ofstream& ofs_running, - const bool is_already_initpsi); + std::ofstream& ofs_running); /** * @brief get the psi_initializer diff --git a/source/module_psi/wavefunc.cpp b/source/module_psi/wavefunc.cpp index c51bcb3048..908346ab55 100644 --- a/source/module_psi/wavefunc.cpp +++ b/source/module_psi/wavefunc.cpp @@ -61,12 +61,13 @@ psi::Psi>* wavefunc::allocate(const int nkstot, const int n wanf2[0].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol); // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int - const size_t memory_cost = sizeof(std::complex) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx); + const size_t memory_cost + = sizeof(std::complex) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx); std::cout << " Memory for wanf2 (MB): " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("WF::wanf2", memory_cost); } - + // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int const size_t memory_cost = sizeof(std::complex) * PARAM.inp.nbands * (PARAM.globalv.npol * npwx); @@ -89,7 +90,8 @@ psi::Psi>* wavefunc::allocate(const int nkstot, const int n } // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int - const size_t memory_cost = sizeof(std::complex) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol); + const size_t memory_cost + = sizeof(std::complex) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol); std::cout << " Memory for wanf2 (MB): " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("WF::wanf2", memory_cost); @@ -184,175 +186,28 @@ int wavefunc::get_starting_nw() const namespace hamilt { -void diago_PAO_in_pw_k2(const int& ik, - psi::Psi>& wvf, +template <> +void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, + const int& ik, + psi::Psi, base_device::DEVICE_CPU>& wvf, ModulePW::PW_Basis_K* wfc_basis, wavefunc* p_wf, const ModuleBase::realArray& tab_at, const int& lmaxkb, - hamilt::Hamilt>* phm_in) + hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) { - ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); - - const int nbasis = wvf.get_nbasis(); - const int nbands = wvf.get_nbands(); - const int current_nbasis = wfc_basis->npwk[ik]; - - if (PARAM.inp.init_wfc == "file") - { - ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); - std::stringstream filename; - int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); - filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; - ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - - - std::vector> s_wfcatom(nbands * nbasis); - castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, nbands * nbasis); - - if (PARAM.inp.ks_solver == "cg") - { - std::vector etfile(nbands, 0.0); - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - s_wfcatom.data(), - wfcatom.nr, - wfcatom.nc, - wvf, - etfile.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); - } - } - - assert(nbands <= wfcatom.nr); - for (int ib = 0; ib < nbands; ib++) - { - for (int ig = 0; ig < nbasis; ig++) - { - wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; - } - } - return; - } - - const int starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - { - return; - } - assert(starting_nw > 0); - std::vector etatom(starting_nw, 0.0); - - // special case here! use Psi(k-1) for the initialization of Psi(k) - // this method should be tested. - /*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0) - { - //this is memsaver case - if(wvf.get_nk() == 1) - { - return; - } - else - { - ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands()); - return; - } - } - */ - - if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0)) - { - p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis); - - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 - { - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace(phm_in, wvf, wvf, etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!"); - } - } - } - else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic") - { - ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); -} - - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - lmaxkb, - wfc_basis, - wfcatom, - tab_at, - PARAM.globalv.nqx, - PARAM.globalv.dq); - - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); - } - - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - - // (7) Diago with cg method. - std::vector> s_wfcatom(starting_nw * nbasis); - castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, starting_nw * nbasis); - - // if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02 - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 - { - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - s_wfcatom.data(), - wfcatom.nr, - wfcatom.nc, - wvf, - etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); - // this diagonalization method is obsoleted now - // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); - } - } - - assert(nbands <= wfcatom.nr); - for (int ib = 0; ib < nbands; ib++) - { - for (int ig = 0; ig < nbasis; ig++) - { - wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; - } - } - } + // TODO? float func } -void diago_PAO_in_pw_k2(const int& ik, - psi::Psi>& wvf, +template <> +void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, + const int& ik, + psi::Psi, base_device::DEVICE_CPU>& wvf, ModulePW::PW_Basis_K* wfc_basis, wavefunc* p_wf, const ModuleBase::realArray& tab_at, const int& lmaxkb, - hamilt::Hamilt>* phm_in) + hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) { ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); @@ -365,10 +220,9 @@ void diago_PAO_in_pw_k2(const int& ik, ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); std::stringstream filename; int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); - filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; + filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - if (PARAM.inp.ks_solver == "cg") { std::vector etfile(nbands, 0.0); @@ -399,23 +253,6 @@ void diago_PAO_in_pw_k2(const int& ik, return; } - // special case here! use Psi(k-1) for the initialization of Psi(k) - // this method should be tested. - /*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0) - { - //this is memsaver case - if(wvf.get_nk() == 1) - { - return; - } - else - { - ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands()); - return; - } - } - */ - const int starting_nw = p_wf->get_starting_nw(); if (starting_nw == 0) { @@ -459,7 +296,8 @@ void diago_PAO_in_pw_k2(const int& ik, PARAM.globalv.nqx, PARAM.globalv.dq); - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 + if (PARAM.inp.init_wfc == "atomic+random" + && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); } @@ -501,32 +339,6 @@ void diago_PAO_in_pw_k2(const int& ik, } } -template <> -void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, - const int& ik, - psi::Psi, base_device::DEVICE_CPU>& wvf, - ModulePW::PW_Basis_K* wfc_basis, - wavefunc* p_wf, - const ModuleBase::realArray& tab_at, - const int& lmaxkb, - hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) -{ - diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in); -} - -template <> -void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, - const int& ik, - psi::Psi, base_device::DEVICE_CPU>& wvf, - ModulePW::PW_Basis_K* wfc_basis, - wavefunc* p_wf, - const ModuleBase::realArray& tab_at, - const int& lmaxkb, - hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) -{ - diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in); -} - #if ((defined __CUDA) || (defined __ROCM)) template <> void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, @@ -538,103 +350,9 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int& lmaxkb, hamilt::Hamilt, base_device::DEVICE_GPU>* phm_in) { - ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); - - const int nbasis = wvf.get_nbasis(); - const int nbands = wvf.get_nbands(); - const int current_nbasis = wfc_basis->npwk[ik]; - int starting_nw = nbands; - - ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); - if (PARAM.inp.init_wfc == "file") - { - std::stringstream filename; - int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); - filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; - ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - } - - starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - return; - assert(starting_nw > 0); - wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - - if (PARAM.inp.init_wfc.substr(0, 6) == "atomic") - { - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - lmaxkb, - wfc_basis, - wfcatom, - tab_at, - PARAM.globalv.nqx, - PARAM.globalv.dq); - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); - } - - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - } - else if (PARAM.inp.init_wfc == "random") - { - p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis); - } - - std::complex* c_wfcatom = nullptr; - if (PARAM.inp.ks_solver != "bpcg") - { - // store wfcatom on the GPU - resmem_cd_op()(gpu_ctx, c_wfcatom, wfcatom.nr * wfcatom.nc); - castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, c_wfcatom, wfcatom.c, wfcatom.nr * wfcatom.nc); - } - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 - { - // (7) Diago with cg method. - if (phm_in != nullptr) - { - std::vector etatom(starting_nw, 0.0); - hsolver::DiagoIterAssist, base_device::DEVICE_GPU>::diagH_subspace_init(phm_in, - c_wfcatom, - wfcatom.nr, - wfcatom.nc, - wvf, - etatom.data()); - } - else - { - // this diagonalization method is obsoleted now - // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); - } - } - else if (PARAM.inp.ks_solver == "dav" || PARAM.inp.ks_solver == "dav_subspace") - { - assert(nbands <= wfcatom.nr); - // replace by haozhihan 2022-11-23 - hsolver::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, - nbands, - c_wfcatom, - wfcatom.nc, - &wvf(0, 0), - nbasis); - } - else if (PARAM.inp.ks_solver == "bpcg") - { - castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, &wvf(0, 0), wfcatom.c, wfcatom.nr * wfcatom.nc); - } - if (PARAM.inp.ks_solver != "bpcg") - { - delmem_cd_op()(gpu_ctx, c_wfcatom); - } + // TODO? float func } + template <> void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int& ik, @@ -679,7 +397,8 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, tab_at, PARAM.globalv.nqx, PARAM.globalv.dq); - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 + if (PARAM.inp.init_wfc == "atomic+random" + && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); }