diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 925298c983..08ad8751f8 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -764,6 +764,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ of_stress_pw.o\ symmetry_rho.o\ symmetry_rhog.o\ + setup_psi.o\ psi_init.o\ elecond.o\ sto_tool.o\ diff --git a/source/source_esolver/esolver_ks.cpp b/source/source_esolver/esolver_ks.cpp index 4768f6afd1..a4e8126ad0 100644 --- a/source/source_esolver/esolver_ks.cpp +++ b/source/source_esolver/esolver_ks.cpp @@ -334,10 +334,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i //---------------------------------------------------------------- // 2) compute magnetization, only for LSDA(spin==2) //---------------------------------------------------------------- - ucell.magnet.compute_mag(ucell.omega, - this->chr.nrxx, - this->chr.nxyz, - this->chr.rho, + ucell.magnet.compute_mag(ucell.omega, this->chr.nrxx, this->chr.nxyz, this->chr.rho, this->pelec->nelec_spin.data()); //---------------------------------------------------------------- @@ -434,20 +431,15 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i MPI_Bcast(this->chr.rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, BP_WORLD); #endif - //---------------------------------------------------------------- // 4) Update potentials (should be done every SF iter) - //---------------------------------------------------------------- - // Hamilt should be used after it is constructed. - // this->phamilt->update(conv_esolver); this->update_pot(ucell, istep, iter, conv_esolver); - //---------------------------------------------------------------- // 5) calculate energies - //---------------------------------------------------------------- // 1 means Harris-Foulkes functional // 2 means Kohn-Sham functional this->pelec->cal_energies(1); this->pelec->cal_energies(2); + if (iter == 1) { this->pelec->f_en.etot_old = this->pelec->f_en.etot; @@ -456,7 +448,6 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i this->pelec->f_en.etot_old = this->pelec->f_en.etot; - //---------------------------------------------------------------- // 6) time and meta-GGA //---------------------------------------------------------------- @@ -481,21 +472,15 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i #ifdef __RAPIDJSON - //---------------------------------------------------------------- // 7) add Json of scf mag - //---------------------------------------------------------------- - Json::add_output_scf_mag(ucell.magnet.tot_mag, - ucell.magnet.abs_mag, + Json::add_output_scf_mag(ucell.magnet.tot_mag, ucell.magnet.abs_mag, this->pelec->f_en.etot * ModuleBase::Ry_to_eV, this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV, - drho, - duration); + drho, duration); #endif //__RAPIDJSON - //---------------------------------------------------------------- // 7) SCF restart information - //---------------------------------------------------------------- if (PARAM.inp.mixing_restart > 0 && iter == this->p_chgmix->mixing_restart_step - 1 && iter != PARAM.inp.scf_nmax) @@ -504,9 +489,7 @@ void ESolver_KS::iter_finish(UnitCell& ucell, const int istep, int& i std::cout << " SCF restart after this step!" << std::endl; } - //---------------------------------------------------------------- // 8) Iter finish - //---------------------------------------------------------------- ESolver_FP::iter_finish(ucell, istep, iter, conv_esolver); } diff --git a/source/source_esolver/esolver_ks_lcaopw.cpp b/source/source_esolver/esolver_ks_lcaopw.cpp index 8597fa6847..7a6d78e339 100644 --- a/source/source_esolver/esolver_ks_lcaopw.cpp +++ b/source/source_esolver/esolver_ks_lcaopw.cpp @@ -81,7 +81,7 @@ namespace ModuleESolver void ESolver_KS_LIP::before_scf(UnitCell& ucell, const int istep) { ESolver_KS_PW::before_scf(ucell, istep); - this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running); + this->stp.p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running); } template @@ -89,9 +89,9 @@ namespace ModuleESolver { ESolver_KS_PW::before_all_runners(ucell, inp); delete this->psi_local; - this->psi_local = new psi::Psi(this->psi->get_nk(), - this->p_psi_init->psi_initer->nbands_start(), - this->psi->get_nbasis(), + this->psi_local = new psi::Psi(this->stp.psi_cpu->get_nk(), + this->stp.p_psi_init->psi_initer->nbands_start(), + this->stp.psi_cpu->get_nbasis(), this->kv.ngk, true); #ifdef __EXX @@ -105,13 +105,12 @@ namespace ModuleESolver ucell.symm, &this->kv, this->psi_local, - this->kspw_psi, + this->stp.psi_t, this->pw_wfc, this->pw_rho, this->sf, &ucell, this->pelec)); - // this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec); } } #endif @@ -147,7 +146,8 @@ namespace ModuleESolver bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false; hsolver::HSolverLIP hsolver_lip_obj(this->pw_wfc); - hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); + hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec, + *this->psi_local, skip_charge,ucell.tpiba,ucell.nat); // add exx #ifdef __EXX @@ -244,7 +244,7 @@ namespace ModuleESolver ModuleIO::write_Vxc(PARAM.inp.nspin, PARAM.globalv.nlocal, GlobalV::DRANK, - *this->kspw_psi, + *this->stp.psi_t, ucell, this->sf, this->solvent, diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 97d4f61533..d2e2612070 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -49,17 +49,9 @@ ESolver_KS_PW::~ESolver_KS_PW() // delete Hamilt this->deallocate_hamilt(); - if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") - { - delete this->kspw_psi; - } - if (PARAM.inp.precision == "single") - { - delete this->__kspw_psi; - } + // mohan add 2025-10-12 + this->stp.clean(); - delete this->psi; - delete this->p_psi_init; } template @@ -89,18 +81,7 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho, this->pw_rhod, this->pw_big, this->solvent, inp); - //! Allocate and initialize psi - this->p_psi_init = new psi::PSIInit(inp.init_wfc, - inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, - this->sf, this->kv, this->ppcell, *this->pw_wfc); - - allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max); - - this->p_psi_init->prepare_init(inp.pw_seed); - - this->kspw_psi = inp.device == "gpu" || inp.precision == "single" - ? new psi::Psi(this->psi[0]) - : reinterpret_cast*>(this->psi); + this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS"); @@ -142,7 +123,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma); - this->p_psi_init->prepare_init(PARAM.inp.pw_seed); + this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed); } //! Init Hamiltonian (cell changed) @@ -156,14 +137,10 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) //! Setup potentials (local, non-local, sc, +U, DFT-1/2) pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid, this->chr, this->locpp, this->ppcell, this->vsep_cell, - this->kspw_psi, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); + this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp); - //! Initialize wave functions - if (!this->already_initpsi) - { - this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running); - this->already_initpsi = true; - } + + this->stp.init(this->p_hamilt); //! Exx calculations if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" @@ -173,7 +150,7 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) { auto hamilt_pw = reinterpret_cast*>(this->p_hamilt); hamilt_pw->set_exx_helper(exx_helper); - exx_helper.set_psi(kspw_psi); + exx_helper.set_psi(this->stp.psi_t); } } @@ -202,7 +179,7 @@ void ESolver_KS_PW::iter_init(UnitCell& ucell, const int istep, const // new DFT+U method will calculate energy when evaluating the Hamiltonian if (dftu->omc != 2) { - dftu->cal_occ_pw(iter, this->kspw_psi, this->pelec->wg, ucell, PARAM.inp.mixing_beta); + dftu->cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta); } dftu->output(ucell); } @@ -271,7 +248,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste PARAM.inp.use_k_continuity); hsolver_pw_obj.solve(this->p_hamilt, - this->kspw_psi[0], + this->stp.psi_t[0], this->pelec, this->pelec->ekb.c, GlobalV::RANK_IN_POOL, @@ -316,7 +293,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int // Related to EXX if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) { - this->pelec->set_exx(exx_helper.cal_exx_energy(kspw_psi)); + this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.psi_t)); } // deband is calculated from "output" charge density @@ -347,12 +324,12 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int double dexx = 0.0; if (PARAM.inp.exx_thr_type == "energy") { - dexx = exx_helper.cal_exx_energy(this->kspw_psi); + dexx = exx_helper.cal_exx_energy(this->stp.psi_t); } - exx_helper.set_psi(this->kspw_psi); + exx_helper.set_psi(this->stp.psi_t); if (PARAM.inp.exx_thr_type == "energy") { - dexx -= exx_helper.cal_exx_energy(this->kspw_psi); + dexx -= exx_helper.cal_exx_energy(this->stp.psi_t); // std::cout << "dexx = " << dexx << std::endl; } bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr; @@ -373,7 +350,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } else { - exx_helper.set_psi(this->kspw_psi); + exx_helper.set_psi(this->stp.psi_t); } } @@ -394,7 +371,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } // the output quantities - ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, + ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu, this->kv, this->pw_wfc, PARAM.inp); } @@ -409,24 +386,16 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const // sunliang 2025-04-10 if (PARAM.inp.out_elf[0] > 0) { - this->ESolver_KS::psi = new psi::Psi(this->psi[0]); + this->ESolver_KS::psi = new psi::Psi(this->stp.psi_cpu[0]); } // Call 'after_scf' of ESolver_KS ESolver_KS::after_scf(ucell, istep, conv_esolver); - // Transfer data from GPU to CPU in pw basis - if (this->device == base_device::GpuDevice) - { - castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(), - this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(), - this->psi[0].size()); - } - // Output quantities ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, - this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp); + this->pw_rho, this->pw_rhod, this->pw_big, this->stp, + this->ctx, this->device, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -442,20 +411,13 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo { Forces ff(ucell.nat); - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); + // mohan add 2025-10-12 + this->stp.update_psi_d(); // Calculate forces ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, &this->sf, this->solvent, &this->locpp, &this->ppcell, - &this->kv, this->pw_wfc, this->__kspw_psi); + &this->kv, this->pw_wfc, this->stp.psi_d); } template @@ -463,18 +425,11 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s { Stress_PW ss(this->pelec); - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); + // mohan add 2025-10-12 + this->stp.update_psi_d(); ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod, - &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->__kspw_psi); + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.psi_d); // external stress double unit_transform = 0.0; @@ -492,9 +447,8 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) ESolver_KS::after_all_runners(ucell); ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->chr, this->kv, this->psi, - this->kspw_psi, this->__kspw_psi, this->sf, - this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); + this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp, + this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); elecstate::teardown_estate_pw(this->pelec, this->vsep_cell); diff --git a/source/source_esolver/esolver_ks_pw.h b/source/source_esolver/esolver_ks_pw.h index 523ae91939..acdd7083ee 100644 --- a/source/source_esolver/esolver_ks_pw.h +++ b/source/source_esolver/esolver_ks_pw.h @@ -1,7 +1,7 @@ #ifndef ESOLVER_KS_PW_H #define ESOLVER_KS_PW_H #include "./esolver_ks.h" -#include "source_psi/psi_init.h" +#include "source_psi/setup_psi.h" // mohan add 20251012 #include "source_pw/module_pwdft/VSep_in_pw.h" #include "source_pw/module_pwdft/global.h" #include "source_pw/module_pwdft/module_exx_helper/exx_helper.h" @@ -54,27 +54,18 @@ class ESolver_KS_PW : public ESolver_KS virtual void allocate_hamilt(const UnitCell& ucell); virtual void deallocate_hamilt(); - //! hide the psi in ESolver_KS for tmp use - psi::Psi, base_device::DEVICE_CPU>* psi = nullptr; - - // psi_initializer controller - psi::PSIInit* p_psi_init = nullptr; + // Electronic wave function psi + Setup_Psi stp; // DFT-1/2 method VSep* vsep_cell = nullptr; + // for get_pchg and get_wf, use ctx as input of fft Device* ctx = {}; + // for device to host data transformation base_device::AbacusDevice_t device = {}; - psi::Psi* kspw_psi = nullptr; - - psi::Psi, Device>* __kspw_psi = nullptr; - - bool already_initpsi = false; - - using castmem_2d_d2h_op - = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; }; } // namespace ModuleESolver #endif diff --git a/source/source_esolver/esolver_sdft_pw.cpp b/source/source_esolver/esolver_sdft_pw.cpp index 170147ba06..1a9057d178 100644 --- a/source/source_esolver/esolver_sdft_pw.cpp +++ b/source/source_esolver/esolver_sdft_pw.cpp @@ -176,8 +176,8 @@ void ESolver_SDFT_PW::hamilt2rho_single(UnitCell& ucell, int istep, i hsolver_pw_sdft_obj.solve(ucell, this->p_hamilt, - this->kspw_psi[0], - this->psi[0], + this->stp.psi_t[0], + this->stp.psi_cpu[0], this->pelec, this->pw_wfc, this->stowf, @@ -233,7 +233,7 @@ void ESolver_SDFT_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& this->locpp, this->ppcell, ucell, - *this->kspw_psi, + *this->stp.psi_t, this->stowf); } @@ -248,7 +248,7 @@ void ESolver_SDFT_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& &this->sf, &this->kv, this->pw_wfc, - *this->kspw_psi, + *this->stp.psi_t, this->stowf, &this->chr, &this->locpp, @@ -279,7 +279,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) this->pw_wfc, &this->kv, this->pelec, - reinterpret_cast>*>(this->psi), + reinterpret_cast>*>(this->stp.psi_cpu), reinterpret_cast>*>(this->p_hamilt), this->stoche, reinterpret_cast, base_device::DEVICE_CPU>*>(&stowf)); @@ -301,7 +301,7 @@ void ESolver_SDFT_PW::after_all_runners(UnitCell& ucell) &this->kv, this->pelec, this->pw_wfc, - this->kspw_psi, + this->stp.psi_t, &this->ppcell, this->p_hamilt, this->stoche, diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index a8c588996a..a67d949ade 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -90,16 +90,18 @@ void ModuleIO::ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &stp, const Device* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + // Transfer data from device (GPU) to host (CPU) in pw basis + stp.copy_d2h(device); + //---------------------------------------------------------- //! 4) Compute density of states (DOS) //---------------------------------------------------------- @@ -160,17 +162,10 @@ void ModuleIO::ctrl_scf_pw(const int istep, //------------------------------------------------------------------ if (inp.out_pchg.size() > 0) { - if (__kspw_psi != nullptr && inp.precision == "single") - { - delete reinterpret_cast, Device>*>(__kspw_psi); - } - - // Refresh __kspw_psi - __kspw_psi = inp.precision == "single" - ? new psi::Psi, Device>(kspw_psi[0]) - : reinterpret_cast, Device>*>(kspw_psi); + // update psi_d + stp.update_psi_d(); - const int nbands = kspw_psi->get_nbands(); + const int nbands = stp.psi_t->get_nbands(); const int ngmc = chr.ngmc; ModuleIO::get_pchg_pw(inp.out_pchg, @@ -179,7 +174,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, pw_rhod->nxyz, ngmc, &ucell, - __kspw_psi, + stp.psi_d, pw_rhod, pw_wfc, ctx, @@ -207,7 +202,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, inp.nnkpfile, inp.wannier_spin); wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, psi); + wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, stp.psi_cpu); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); } @@ -219,7 +214,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, { std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); berryphase bp; - bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, psi, pw_rho, pw_wfc, kv); + bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, stp.psi_cpu, pw_rho, pw_wfc, kv); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); } @@ -241,7 +236,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(kspw_psi), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(stp.psi_t), pelec->wg); } @@ -257,9 +252,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -276,7 +269,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.out_ldos[0]) { ModuleIO::cal_ldos_pw(reinterpret_cast>*>(pelec), - psi[0], para_grid, ucell); + stp.psi_cpu[0], para_grid, ucell); } //---------------------------------------------------------- @@ -296,7 +289,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, << " a.u." << std::endl; } Numerical_Basis numerical_basis; - numerical_basis.output_overlap(psi[0], sf, kv, pw_wfc, ucell, i); + numerical_basis.output_overlap(stp.psi_cpu[0], sf, kv, pw_wfc, ucell, i); } ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); } @@ -307,23 +300,15 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.out_wfc_norm.size() > 0 || inp.out_wfc_re_im.size() > 0) { - if (__kspw_psi != nullptr && inp.precision == "single") - { - delete reinterpret_cast, Device>*>(__kspw_psi); - } - - // Refresh __kspw_psi - __kspw_psi = inp.precision == "single" - ? new psi::Psi, Device>(kspw_psi[0]) - : reinterpret_cast, Device>*>(kspw_psi); + stp.update_psi_d(); ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, - kspw_psi->get_nbands(), + stp.psi_t->get_nbands(), inp.nspin, pw_rhod->nxyz, &ucell, - __kspw_psi, + stp.psi_d, pw_wfc, ctx, para_grid, @@ -339,7 +324,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, if (inp.cal_cond) { using Real = typename GetTypeReal::type; - EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, kspw_psi, &ppcell); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, stp.psi_t, &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, inp.cond_wcut, @@ -376,7 +361,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - kspw_psi, + stp.psi_t, pelec, pw_wfc, pw_rho, @@ -399,10 +384,9 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -417,10 +401,9 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, const base_device::DEVICE_CPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -436,10 +419,9 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -454,10 +436,9 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, const base_device::DEVICE_GPU* ctx, + const base_device::AbacusDevice_t &device, const Parallel_Grid ¶_grid, const Input_para& inp); #endif @@ -471,9 +452,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -490,9 +469,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_CPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -510,9 +487,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, @@ -529,9 +504,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device - psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Setup_Psi, base_device::DEVICE_GPU> &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 87fea245b0..b10870c9d1 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -4,6 +4,7 @@ #include "source_base/module_device/device.h" // use Device #include "source_psi/psi.h" // define psi #include "source_estate/elecstate_lcao.h" // use pelec +#include "source_psi/setup_psi.h" // use Setup_Psi class namespace ModuleIO { @@ -28,10 +29,9 @@ void ctrl_scf_pw(const int istep, const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &stp, const Device* ctx, + const base_device::AbacusDevice_t &device, // mohan add 2025-10-15 const Parallel_Grid ¶_grid, const Input_para& inp); @@ -44,9 +44,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rhod, Charge &chr, K_Vectors &kv, - psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi* kspw_psi, - psi::Psi, Device>* __kspw_psi, + Setup_Psi &stp, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, diff --git a/source/source_psi/CMakeLists.txt b/source/source_psi/CMakeLists.txt index a1037885f3..f871d2feee 100644 --- a/source/source_psi/CMakeLists.txt +++ b/source/source_psi/CMakeLists.txt @@ -7,6 +7,7 @@ add_library( add_library( psi_overall_init OBJECT + setup_psi.cpp psi_init.cpp ) @@ -22,6 +23,7 @@ add_library( psi_initializer_nao_random.cpp ) + if(ENABLE_COVERAGE) add_coverage(psi) add_coverage(psi_initializer) @@ -32,4 +34,4 @@ if (BUILD_TESTING) if(ENABLE_MPI) add_subdirectory(test) endif() -endif() \ No newline at end of file +endif() diff --git a/source/source_psi/setup_psi.cpp b/source/source_psi/setup_psi.cpp new file mode 100644 index 0000000000..84ef806bd9 --- /dev/null +++ b/source/source_psi/setup_psi.cpp @@ -0,0 +1,105 @@ +#include "source_psi/setup_psi.h" +#include "source_lcao/setup_deepks.h" +#include "source_lcao/LCAO_domain.h" +#include "source_io/module_parameter/parameter.h" // use parameter + +template +Setup_Psi::Setup_Psi(){} + +template +Setup_Psi::~Setup_Psi(){} + +template +void Setup_Psi::before_runner( + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp) +{ + //! Allocate and initialize psi + this->p_psi_init = new psi::PSIInit(inp.init_wfc, + inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell, + sf, kv, ppcell, pw_wfc); + + //! Allocate memory for cpu version of psi + allocate_psi(this->psi_cpu, kv.get_nks(), kv.ngk, PARAM.globalv.nbands_l, pw_wfc.npwk_max); + + this->p_psi_init->prepare_init(inp.pw_seed); + + //! If GPU or single precision, allocate a new psi (psi_t). + //! otherwise, transform psi_cpu to psi_t + this->psi_t = inp.device == "gpu" || inp.precision == "single" + ? new psi::Psi(this->psi_cpu[0]) + : reinterpret_cast*>(this->psi_cpu); +} + + +template +void Setup_Psi::update_psi_d() +{ + if (this->psi_d != nullptr && PARAM.inp.precision == "single") + { + delete reinterpret_cast, Device>*>(this->psi_d); + } + + // Refresh this->psi_d + this->psi_d = PARAM.inp.precision == "single" + ? new psi::Psi, Device>(this->psi_t[0]) + : reinterpret_cast, Device>*>(this->psi_t); +} + +template +void Setup_Psi::init(hamilt::Hamilt* p_hamilt) +{ + //! Initialize wave functions + if (!this->already_initpsi) + { + this->p_psi_init->initialize_psi(this->psi_cpu, this->psi_t, p_hamilt, GlobalV::ofs_running); + this->already_initpsi = true; + } +} + + +// Transfer data from GPU to CPU in pw basis +template +void Setup_Psi::copy_d2h(const base_device::AbacusDevice_t &device) +{ + if (device == base_device::GpuDevice) + { + castmem_2d_d2h_op()(this->psi_cpu[0].get_pointer() - this->psi_cpu[0].get_psi_bias(), + this->psi_t[0].get_pointer() - this->psi_t[0].get_psi_bias(), + this->psi_cpu[0].size()); + } + else + { + // do nothing + } + return; +} + + + +template +void Setup_Psi::clean() +{ + if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") + { + delete this->psi_t; + } + if (PARAM.inp.precision == "single") + { + delete this->psi_d; + } + + delete this->psi_cpu; + delete this->p_psi_init; +} + +template class Setup_Psi, base_device::DEVICE_CPU>; +template class Setup_Psi, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class Setup_Psi, base_device::DEVICE_GPU>; +template class Setup_Psi, base_device::DEVICE_GPU>; +#endif diff --git a/source/source_psi/setup_psi.h b/source/source_psi/setup_psi.h new file mode 100644 index 0000000000..40a68a5020 --- /dev/null +++ b/source/source_psi/setup_psi.h @@ -0,0 +1,74 @@ +#ifndef SETUP_PSI_H +#define SETUP_PSI_H + +#include "source_psi/psi_init.h" +#include "source_cell/unitcell.h" +#include "source_cell/klist.h" +#include "source_pw/module_pwdft/structure_factor.h" +#include "source_basis/module_pw/pw_basis_k.h" +#include "source_pw/module_pwdft/VNL_in_pw.h" +#include "source_io/module_parameter/input_parameter.h" +#include "source_base/module_device/device.h" +#include "source_hamilt/hamilt.h" + +template +class Setup_Psi +{ + public: + + Setup_Psi(); + ~Setup_Psi(); + + //------------ + // variables + // psi_cpu, complex on cpu + // psi_t, complex on cpu/gpu + // psi_d, complex on cpu/gpu + //------------ + + // originally, this term is psi + // for PW, we have psi_cpu + psi::Psi, base_device::DEVICE_CPU>* psi_cpu = nullptr; + + // originally, this term is kspw_psi + // if CPU, kspw_psi = psi, otherwise, kspw_psi has a new copy + psi::Psi* psi_t = nullptr; + + // originally, this term is __kspw_psi + psi::Psi, Device>* psi_d = nullptr; + + // psi_initializer controller + psi::PSIInit* p_psi_init = nullptr; + + bool already_initpsi = false; + + //------------ + // functions + //------------ + + void before_runner( + const UnitCell &ucell, + const K_Vectors &kv, + const Structure_Factor &sf, + const ModulePW::PW_Basis_K &pw_wfc, + const pseudopot_cell_vnl &ppcell, + const Input_para &inp); + + void init(hamilt::Hamilt* p_hamilt); + + void update_psi_d(); + + // Transfer data from device to host in pw basis + void copy_d2h(const base_device::AbacusDevice_t &device); + + void clean(); + + private: + + using castmem_2d_d2h_op + = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; + +}; + + +#endif