diff --git a/source/module_psi/wavefunc.cpp b/source/module_psi/wavefunc.cpp index c51bcb3048..63a168b05a 100644 --- a/source/module_psi/wavefunc.cpp +++ b/source/module_psi/wavefunc.cpp @@ -61,12 +61,13 @@ psi::Psi>* wavefunc::allocate(const int nkstot, const int n wanf2[0].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol); // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int - const size_t memory_cost = sizeof(std::complex) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx); + const size_t memory_cost + = sizeof(std::complex) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx); std::cout << " Memory for wanf2 (MB): " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("WF::wanf2", memory_cost); } - + // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int const size_t memory_cost = sizeof(std::complex) * PARAM.inp.nbands * (PARAM.globalv.npol * npwx); @@ -89,7 +90,8 @@ psi::Psi>* wavefunc::allocate(const int nkstot, const int n } // WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int - const size_t memory_cost = sizeof(std::complex) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol); + const size_t memory_cost + = sizeof(std::complex) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol); std::cout << " Memory for wanf2 (MB): " << static_cast(memory_cost) / 1024.0 / 1024.0 << std::endl; ModuleBase::Memory::record("WF::wanf2", memory_cost); @@ -184,175 +186,28 @@ int wavefunc::get_starting_nw() const namespace hamilt { -void diago_PAO_in_pw_k2(const int& ik, - psi::Psi>& wvf, +template <> +void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, + const int& ik, + psi::Psi, base_device::DEVICE_CPU>& wvf, ModulePW::PW_Basis_K* wfc_basis, wavefunc* p_wf, const ModuleBase::realArray& tab_at, const int& lmaxkb, - hamilt::Hamilt>* phm_in) + hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) { - ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); - - const int nbasis = wvf.get_nbasis(); - const int nbands = wvf.get_nbands(); - const int current_nbasis = wfc_basis->npwk[ik]; - - if (PARAM.inp.init_wfc == "file") - { - ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); - std::stringstream filename; - int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); - filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; - ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - - - std::vector> s_wfcatom(nbands * nbasis); - castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, nbands * nbasis); - - if (PARAM.inp.ks_solver == "cg") - { - std::vector etfile(nbands, 0.0); - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - s_wfcatom.data(), - wfcatom.nr, - wfcatom.nc, - wvf, - etfile.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); - } - } - - assert(nbands <= wfcatom.nr); - for (int ib = 0; ib < nbands; ib++) - { - for (int ig = 0; ig < nbasis; ig++) - { - wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; - } - } - return; - } - - const int starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - { - return; - } - assert(starting_nw > 0); - std::vector etatom(starting_nw, 0.0); - - // special case here! use Psi(k-1) for the initialization of Psi(k) - // this method should be tested. - /*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0) - { - //this is memsaver case - if(wvf.get_nk() == 1) - { - return; - } - else - { - ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands()); - return; - } - } - */ - - if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0)) - { - p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis); - - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 - { - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace(phm_in, wvf, wvf, etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!"); - } - } - } - else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic") - { - ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); + // TODO float func } - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - lmaxkb, - wfc_basis, - wfcatom, - tab_at, - PARAM.globalv.nqx, - PARAM.globalv.dq); - - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); - } - - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - - // (7) Diago with cg method. - std::vector> s_wfcatom(starting_nw * nbasis); - castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, starting_nw * nbasis); - - // if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02 - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 - { - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - s_wfcatom.data(), - wfcatom.nr, - wfcatom.nc, - wvf, - etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); - // this diagonalization method is obsoleted now - // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); - } - } - - assert(nbands <= wfcatom.nr); - for (int ib = 0; ib < nbands; ib++) - { - for (int ig = 0; ig < nbasis; ig++) - { - wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; - } - } - } -} - -void diago_PAO_in_pw_k2(const int& ik, - psi::Psi>& wvf, +template <> +void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, + const int& ik, + psi::Psi, base_device::DEVICE_CPU>& wvf, ModulePW::PW_Basis_K* wfc_basis, wavefunc* p_wf, const ModuleBase::realArray& tab_at, const int& lmaxkb, - hamilt::Hamilt>* phm_in) + hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) { ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); @@ -365,10 +220,9 @@ void diago_PAO_in_pw_k2(const int& ik, ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); std::stringstream filename; int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); - filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; + filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - if (PARAM.inp.ks_solver == "cg") { std::vector etfile(nbands, 0.0); @@ -396,43 +250,19 @@ void diago_PAO_in_pw_k2(const int& ik, wvf(ib, ig) = wfcatom(ib, ig); } } - return; - } - - // special case here! use Psi(k-1) for the initialization of Psi(k) - // this method should be tested. - /*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0) - { - //this is memsaver case - if(wvf.get_nk() == 1) - { - return; - } - else - { - ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands()); - return; - } } - */ - - const int starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - { - return; - } - - assert(starting_nw > 0); - std::vector etatom(starting_nw, 0.0); - - if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0)) + else if (PARAM.inp.init_wfc == "random" + || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0)) { p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis); - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 + + if (PARAM.inp.ks_solver == "cg") { + std::vector etrandom(nbands, 0.0); + if (phm_in != nullptr) { - hsolver::DiagoIterAssist>::diagH_subspace(phm_in, wvf, wvf, etatom.data()); + hsolver::DiagoIterAssist>::diagH_subspace(phm_in, wvf, wvf, etrandom.data()); return; } else @@ -443,7 +273,15 @@ void diago_PAO_in_pw_k2(const int& ik, } else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic") { - ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc + const int starting_nw = p_wf->get_starting_nw(); + if (starting_nw == 0) + { + return; + } + assert(starting_nw > 0); + + ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); + if (PARAM.inp.test_wf) { ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); @@ -459,7 +297,8 @@ void diago_PAO_in_pw_k2(const int& ik, PARAM.globalv.nqx, PARAM.globalv.dq); - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 + if (PARAM.inp.init_wfc == "atomic+random" + && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); } @@ -474,6 +313,7 @@ void diago_PAO_in_pw_k2(const int& ik, // if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02 if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 { + std::vector etatom(starting_nw, 0.0); if (phm_in != nullptr) { hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, @@ -501,33 +341,8 @@ void diago_PAO_in_pw_k2(const int& ik, } } -template <> -void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, - const int& ik, - psi::Psi, base_device::DEVICE_CPU>& wvf, - ModulePW::PW_Basis_K* wfc_basis, - wavefunc* p_wf, - const ModuleBase::realArray& tab_at, - const int& lmaxkb, - hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) -{ - diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in); -} - -template <> -void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, - const int& ik, - psi::Psi, base_device::DEVICE_CPU>& wvf, - ModulePW::PW_Basis_K* wfc_basis, - wavefunc* p_wf, - const ModuleBase::realArray& tab_at, - const int& lmaxkb, - hamilt::Hamilt, base_device::DEVICE_CPU>* phm_in) -{ - diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in); -} - #if ((defined __CUDA) || (defined __ROCM)) + template <> void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int& ik, @@ -538,103 +353,9 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int& lmaxkb, hamilt::Hamilt, base_device::DEVICE_GPU>* phm_in) { - ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); - - const int nbasis = wvf.get_nbasis(); - const int nbands = wvf.get_nbands(); - const int current_nbasis = wfc_basis->npwk[ik]; - int starting_nw = nbands; - - ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); - if (PARAM.inp.init_wfc == "file") - { - std::stringstream filename; - int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); - filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; - ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - } - - starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - return; - assert(starting_nw > 0); - wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - - if (PARAM.inp.init_wfc.substr(0, 6) == "atomic") - { - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - lmaxkb, - wfc_basis, - wfcatom, - tab_at, - PARAM.globalv.nqx, - PARAM.globalv.dq); - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); - } - - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - } - else if (PARAM.inp.init_wfc == "random") - { - p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis); - } - - std::complex* c_wfcatom = nullptr; - if (PARAM.inp.ks_solver != "bpcg") - { - // store wfcatom on the GPU - resmem_cd_op()(gpu_ctx, c_wfcatom, wfcatom.nr * wfcatom.nc); - castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, c_wfcatom, wfcatom.c, wfcatom.nr * wfcatom.nc); - } - if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02 - { - // (7) Diago with cg method. - if (phm_in != nullptr) - { - std::vector etatom(starting_nw, 0.0); - hsolver::DiagoIterAssist, base_device::DEVICE_GPU>::diagH_subspace_init(phm_in, - c_wfcatom, - wfcatom.nr, - wfcatom.nc, - wvf, - etatom.data()); - } - else - { - // this diagonalization method is obsoleted now - // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); - } - } - else if (PARAM.inp.ks_solver == "dav" || PARAM.inp.ks_solver == "dav_subspace") - { - assert(nbands <= wfcatom.nr); - // replace by haozhihan 2022-11-23 - hsolver::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, - nbands, - c_wfcatom, - wfcatom.nc, - &wvf(0, 0), - nbasis); - } - else if (PARAM.inp.ks_solver == "bpcg") - { - castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, &wvf(0, 0), wfcatom.c, wfcatom.nr * wfcatom.nc); - } - if (PARAM.inp.ks_solver != "bpcg") - { - delmem_cd_op()(gpu_ctx, c_wfcatom); - } + // TODO float } + template <> void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int& ik, @@ -679,7 +400,8 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, tab_at, PARAM.globalv.nqx, PARAM.globalv.dq); - if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 + if (PARAM.inp.init_wfc == "atomic+random" + && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); } @@ -742,6 +464,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, delmem_zd_op()(gpu_ctx, z_wfcatom); } } + #endif } // namespace hamilt