Skip to content
Merged
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
of_stress_pw.o\
symmetry_rho.o\
symmetry_rhog.o\
setup_psi.o\
psi_init.o\
elecond.o\
sto_tool.o\
Expand Down
25 changes: 4 additions & 21 deletions source/source_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
//----------------------------------------------------------------
// 2) compute magnetization, only for LSDA(spin==2)
//----------------------------------------------------------------
ucell.magnet.compute_mag(ucell.omega,
this->chr.nrxx,
this->chr.nxyz,
this->chr.rho,
ucell.magnet.compute_mag(ucell.omega, this->chr.nrxx, this->chr.nxyz, this->chr.rho,
this->pelec->nelec_spin.data());

//----------------------------------------------------------------
Expand Down Expand Up @@ -434,20 +431,15 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
MPI_Bcast(this->chr.rho[0], this->pw_rhod->nrxx, MPI_DOUBLE, 0, BP_WORLD);
#endif

//----------------------------------------------------------------
// 4) Update potentials (should be done every SF iter)
//----------------------------------------------------------------
// Hamilt should be used after it is constructed.
// this->phamilt->update(conv_esolver);
this->update_pot(ucell, istep, iter, conv_esolver);

//----------------------------------------------------------------
// 5) calculate energies
//----------------------------------------------------------------
// 1 means Harris-Foulkes functional
// 2 means Kohn-Sham functional
this->pelec->cal_energies(1);
this->pelec->cal_energies(2);

if (iter == 1)
{
this->pelec->f_en.etot_old = this->pelec->f_en.etot;
Expand All @@ -456,7 +448,6 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
this->pelec->f_en.etot_old = this->pelec->f_en.etot;



//----------------------------------------------------------------
// 6) time and meta-GGA
//----------------------------------------------------------------
Expand All @@ -481,21 +472,15 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i


#ifdef __RAPIDJSON
//----------------------------------------------------------------
// 7) add Json of scf mag
//----------------------------------------------------------------
Json::add_output_scf_mag(ucell.magnet.tot_mag,
ucell.magnet.abs_mag,
Json::add_output_scf_mag(ucell.magnet.tot_mag, ucell.magnet.abs_mag,
this->pelec->f_en.etot * ModuleBase::Ry_to_eV,
this->pelec->f_en.etot_delta * ModuleBase::Ry_to_eV,
drho,
duration);
drho, duration);
#endif //__RAPIDJSON


//----------------------------------------------------------------
// 7) SCF restart information
//----------------------------------------------------------------
if (PARAM.inp.mixing_restart > 0
&& iter == this->p_chgmix->mixing_restart_step - 1
&& iter != PARAM.inp.scf_nmax)
Expand All @@ -504,9 +489,7 @@ void ESolver_KS<T, Device>::iter_finish(UnitCell& ucell, const int istep, int& i
std::cout << " SCF restart after this step!" << std::endl;
}

//----------------------------------------------------------------
// 8) Iter finish
//----------------------------------------------------------------
ESolver_FP::iter_finish(ucell, istep, iter, conv_esolver);
}

Expand Down
16 changes: 8 additions & 8 deletions source/source_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ namespace ModuleESolver
void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
{
ESolver_KS_PW<T>::before_scf(ucell, istep);
this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
this->stp.p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
}

template <typename T>
void ESolver_KS_LIP<T>::before_all_runners(UnitCell& ucell, const Input_para& inp)
{
ESolver_KS_PW<T>::before_all_runners(ucell, inp);
delete this->psi_local;
this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
this->p_psi_init->psi_initer->nbands_start(),
this->psi->get_nbasis(),
this->psi_local = new psi::Psi<T>(this->stp.psi_cpu->get_nk(),
this->stp.p_psi_init->psi_initer->nbands_start(),
this->stp.psi_cpu->get_nbasis(),
this->kv.ngk,
true);
#ifdef __EXX
Expand All @@ -105,13 +105,12 @@ namespace ModuleESolver
ucell.symm,
&this->kv,
this->psi_local,
this->kspw_psi,
this->stp.psi_t,
this->pw_wfc,
this->pw_rho,
this->sf,
&ucell,
this->pelec));
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
}
}
#endif
Expand Down Expand Up @@ -147,7 +146,8 @@ namespace ModuleESolver
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat);
hsolver_lip_obj.solve(this->p_hamilt, this->stp.psi_t[0], this->pelec,
*this->psi_local, skip_charge,ucell.tpiba,ucell.nat);

// add exx
#ifdef __EXX
Expand Down Expand Up @@ -244,7 +244,7 @@ namespace ModuleESolver
ModuleIO::write_Vxc(PARAM.inp.nspin,
PARAM.globalv.nlocal,
GlobalV::DRANK,
*this->kspw_psi,
*this->stp.psi_t,
ucell,
this->sf,
this->solvent,
Expand Down
99 changes: 28 additions & 71 deletions source/source_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,9 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
// delete Hamilt
this->deallocate_hamilt();

if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
delete this->kspw_psi;
}
if (PARAM.inp.precision == "single")
{
delete this->__kspw_psi;
}
// mohan add 2025-10-12
this->stp.clean();

delete this->psi;
delete this->p_psi_init;
}

template <typename T, typename Device>
Expand Down Expand Up @@ -89,18 +81,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
this->locpp, this->ppcell, this->vsep_cell, this->pw_wfc, this->pw_rho,
this->pw_rhod, this->pw_big, this->solvent, inp);

//! Allocate and initialize psi
this->p_psi_init = new psi::PSIInit<T, Device>(inp.init_wfc,
inp.ks_solver, inp.basis_type, GlobalV::MY_RANK, ucell,
this->sf, this->kv, this->ppcell, *this->pw_wfc);

allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk, PARAM.globalv.nbands_l, this->pw_wfc->npwk_max);

this->p_psi_init->prepare_init(inp.pw_seed);

this->kspw_psi = inp.device == "gpu" || inp.precision == "single"
? new psi::Psi<T, Device>(this->psi[0])
: reinterpret_cast<psi::Psi<T, Device>*>(this->psi);
this->stp.before_runner(ucell, this->kv, this->sf, *this->pw_wfc, this->ppcell, PARAM.inp);

ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT BASIS");

Expand Down Expand Up @@ -142,7 +123,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)

this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);

this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
this->stp.p_psi_init->prepare_init(PARAM.inp.pw_seed);
}

//! Init Hamiltonian (cell changed)
Expand All @@ -156,14 +137,10 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
//! Setup potentials (local, non-local, sc, +U, DFT-1/2)
pw::setup_pot(istep, ucell, this->kv, this->sf, this->pelec, this->Pgrid,
this->chr, this->locpp, this->ppcell, this->vsep_cell,
this->kspw_psi, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp);
this->stp.psi_t, this->p_hamilt, this->pw_wfc, this->pw_rhod, PARAM.inp);

//! Initialize wave functions
if (!this->already_initpsi)
{
this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
this->already_initpsi = true;
}

this->stp.init(this->p_hamilt);

//! Exx calculations
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
Expand All @@ -173,7 +150,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
{
auto hamilt_pw = reinterpret_cast<hamilt::HamiltPW<T, Device>*>(this->p_hamilt);
hamilt_pw->set_exx_helper(exx_helper);
exx_helper.set_psi(kspw_psi);
exx_helper.set_psi(this->stp.psi_t);
}
}

Expand Down Expand Up @@ -202,7 +179,7 @@ void ESolver_KS_PW<T, Device>::iter_init(UnitCell& ucell, const int istep, const
// new DFT+U method will calculate energy when evaluating the Hamiltonian
if (dftu->omc != 2)
{
dftu->cal_occ_pw(iter, this->kspw_psi, this->pelec->wg, ucell, PARAM.inp.mixing_beta);
dftu->cal_occ_pw(iter, this->stp.psi_t, this->pelec->wg, ucell, PARAM.inp.mixing_beta);
}
dftu->output(ucell);
}
Expand Down Expand Up @@ -271,7 +248,7 @@ void ESolver_KS_PW<T, Device>::hamilt2rho_single(UnitCell& ucell, const int iste
PARAM.inp.use_k_continuity);

hsolver_pw_obj.solve(this->p_hamilt,
this->kspw_psi[0],
this->stp.psi_t[0],
this->pelec,
this->pelec->ekb.c,
GlobalV::RANK_IN_POOL,
Expand Down Expand Up @@ -316,7 +293,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
// Related to EXX
if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter)
{
this->pelec->set_exx(exx_helper.cal_exx_energy(kspw_psi));
this->pelec->set_exx(exx_helper.cal_exx_energy(this->stp.psi_t));
}

// deband is calculated from "output" charge density
Expand Down Expand Up @@ -347,12 +324,12 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
double dexx = 0.0;
if (PARAM.inp.exx_thr_type == "energy")
{
dexx = exx_helper.cal_exx_energy(this->kspw_psi);
dexx = exx_helper.cal_exx_energy(this->stp.psi_t);
}
exx_helper.set_psi(this->kspw_psi);
exx_helper.set_psi(this->stp.psi_t);
if (PARAM.inp.exx_thr_type == "energy")
{
dexx -= exx_helper.cal_exx_energy(this->kspw_psi);
dexx -= exx_helper.cal_exx_energy(this->stp.psi_t);
// std::cout << "dexx = " << dexx << std::endl;
}
bool conv_ene = std::abs(dexx) < PARAM.inp.exx_ene_thr;
Expand All @@ -373,7 +350,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
}
else
{
exx_helper.set_psi(this->kspw_psi);
exx_helper.set_psi(this->stp.psi_t);
}
}

Expand All @@ -394,7 +371,7 @@ void ESolver_KS_PW<T, Device>::iter_finish(UnitCell& ucell, const int istep, int
}

// the output quantities
ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi,
ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->stp.psi_cpu,
this->kv, this->pw_wfc, PARAM.inp);
}

Expand All @@ -409,24 +386,19 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
// sunliang 2025-04-10
if (PARAM.inp.out_elf[0] > 0)
{
this->ESolver_KS<T, Device>::psi = new psi::Psi<T>(this->psi[0]);
this->ESolver_KS<T, Device>::psi = new psi::Psi<T>(this->stp.psi_cpu[0]);
}

// Call 'after_scf' of ESolver_KS
ESolver_KS<T, Device>::after_scf(ucell, istep, conv_esolver);

// Transfer data from GPU to CPU in pw basis
if (this->device == base_device::GpuDevice)
{
castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(),
this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(),
this->psi[0].size());
}
this->stp.copy_g2c(this->device);

// Output quantities
ModuleIO::ctrl_scf_pw<T, Device>(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc,
this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi,
this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp);
this->pw_rho, this->pw_rhod, this->pw_big, this->stp,
this->ctx, this->Pgrid, PARAM.inp);

ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
}
Expand All @@ -442,39 +414,25 @@ void ESolver_KS_PW<T, Device>::cal_force(UnitCell& ucell, ModuleBase::matrix& fo
{
Forces<double, Device> ff(ucell.nat);

if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single")
{
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
}

// Refresh __kspw_psi
this->__kspw_psi = PARAM.inp.precision == "single"
? new psi::Psi<std::complex<double>, Device>(this->kspw_psi[0])
: reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->kspw_psi);
// mohan add 2025-10-12
this->stp.update_psi_d();

// Calculate forces
ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm,
&this->sf, this->solvent, &this->locpp, &this->ppcell,
&this->kv, this->pw_wfc, this->__kspw_psi);
&this->kv, this->pw_wfc, this->stp.psi_d);
}

template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::cal_stress(UnitCell& ucell, ModuleBase::matrix& stress)
{
Stress_PW<double, Device> ss(this->pelec);

if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single")
{
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
}

// Refresh __kspw_psi
this->__kspw_psi = PARAM.inp.precision == "single"
? new psi::Psi<std::complex<double>, Device>(this->kspw_psi[0])
: reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->kspw_psi);
// mohan add 2025-10-12
this->stp.update_psi_d();

ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod,
&ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->__kspw_psi);
&ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->stp.psi_d);

// external stress
double unit_transform = 0.0;
Expand All @@ -492,9 +450,8 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
ESolver_KS<T, Device>::after_all_runners(ucell);

ModuleIO::ctrl_runner_pw<T, Device>(ucell, this->pelec, this->pw_wfc,
this->pw_rho, this->pw_rhod, this->chr, this->kv, this->psi,
this->kspw_psi, this->__kspw_psi, this->sf,
this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp);
this->pw_rho, this->pw_rhod, this->chr, this->kv, this->stp,
this->sf, this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp);

elecstate::teardown_estate_pw<T, Device>(this->pelec, this->vsep_cell);

Expand Down
Loading
Loading