From 415b55f57189e64214e4b7e269b278f9304444f0 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 15:30:30 +0800 Subject: [PATCH 01/18] update small places in charge_mixing_residual.cpp --- .../module_charge/charge_mixing_residual.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/source/source_estate/module_charge/charge_mixing_residual.cpp b/source/source_estate/module_charge/charge_mixing_residual.cpp index 9e9c5e0131..8862ede35e 100644 --- a/source/source_estate/module_charge/charge_mixing_residual.cpp +++ b/source/source_estate/module_charge/charge_mixing_residual.cpp @@ -9,11 +9,13 @@ double Charge_Mixing::get_drho(Charge* chr, const double nelec) { ModuleBase::TITLE("Charge_Mixing", "get_drho"); ModuleBase::timer::tick("Charge_Mixing", "get_drho"); + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); double drho = 0.0; if (PARAM.inp.scf_thr_type == 1) { - for (int is = 0; is < PARAM.inp.nspin; ++is) + for (int is = 0; is < nspin; ++is) { ModuleBase::GlobalFunc::NOTE("Perform FFT on rho(r) to obtain rho(G)."); chr->rhopw->real2recip(chr->rho[is], chr->rhog[is]); @@ -23,15 +25,16 @@ double Charge_Mixing::get_drho(Charge* chr, const double nelec) } ModuleBase::GlobalFunc::NOTE("Calculate the charge difference between rho(G) and rho_save(G)"); - std::vector> drhog(PARAM.inp.nspin * this->rhopw->npw); + std::vector> drhog(nspin * this->rhopw->npw); #ifdef _OPENMP #pragma omp parallel for collapse(2) schedule(static, 512) #endif - for (int is = 0; is < PARAM.inp.nspin; ++is) + for (int is = 0; is < nspin; ++is) { + const int is_idx = is * this->rhopw->npw; for (int ig = 0; ig < this->rhopw->npw; ig++) { - drhog[is * rhopw->npw + ig] = chr->rhog[is][ig] - chr->rhog_save[is][ig]; + drhog[is_idx + ig] = chr->rhog[is][ig] - chr->rhog_save[is][ig]; } } @@ -42,7 +45,7 @@ double Charge_Mixing::get_drho(Charge* chr, const double nelec) { // Note: Maybe it is wrong. // The inner_product_real function (L1-norm) is different from that (L2-norm) in mixing. - for (int is = 0; is < PARAM.inp.nspin; is++) + for (int is = 0; is < nspin; is++) { if (is != 0 && is != 3 && PARAM.globalv.domag_z) { From 1156f2161b08b0bbca8314755c5f37f7fe359116 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 15:43:03 +0800 Subject: [PATCH 02/18] update charge_mixing_preconditioner --- .../charge_mixing_preconditioner.cpp | 90 +++++++++++++------ 1 file changed, 62 insertions(+), 28 deletions(-) diff --git a/source/source_estate/module_charge/charge_mixing_preconditioner.cpp b/source/source_estate/module_charge/charge_mixing_preconditioner.cpp index 804705be99..165b967ec8 100644 --- a/source/source_estate/module_charge/charge_mixing_preconditioner.cpp +++ b/source/source_estate/module_charge/charge_mixing_preconditioner.cpp @@ -6,21 +6,32 @@ void Charge_Mixing::Kerker_screen_recip(std::complex* drhog) { - if (this->mixing_gg0 <= 0.0 || this->mixing_beta <= 0.1) { - return; -} + ModuleBase::TITLE("Charge_Mixing", "Kerker_screen_recip"); + + if (this->mixing_gg0 <= 0.0 || this->mixing_beta <= 0.1) + { + return; + } + + ModuleBase::timer::tick("Charge_Mixing", "Kerker_screen_recip"); + + const int nspin = PARAM.inp.nspin; + double fac = 0.0; double gg0 = 0.0; double amin = 0.0; /// consider a resize for mixing_angle - int resize_tmp = 1; - if (PARAM.inp.nspin == 4 && this->mixing_angle > 0) { resize_tmp = 2; -} + int resize_tmp = 1; + if (nspin == 4 && this->mixing_angle > 0) + { + resize_tmp = 2; + } /// implement Kerker for density and magnetization separately - for (int is = 0; is < PARAM.inp.nspin / resize_tmp; ++is) + for (int is = 0; is < nspin / resize_tmp; ++is) { + const int is_idx = is * this->rhopw->npw; /// new mixing method only support nspin=2 not nspin=4 if (is >= 1) { @@ -29,10 +40,10 @@ void Charge_Mixing::Kerker_screen_recip(std::complex* drhog) #ifdef __DEBUG assert(is == 1); // make sure break works #endif - double is_mag = PARAM.inp.nspin - 1; + double is_mag = nspin - 1; //for (int ig = 0; ig < this->rhopw->npw * is_mag; ig++) //{ - // drhog[is * this->rhopw->npw + ig] *= 1; + // drhog[is_idx + ig] *= 1; //} break; } @@ -49,29 +60,45 @@ void Charge_Mixing::Kerker_screen_recip(std::complex* drhog) #ifdef _OPENMP #pragma omp parallel for schedule(static, 512) #endif + const double gg0_amin = this->mixing_gg0_min / amin; + for (int ig = 0; ig < this->rhopw->npw; ++ig) { double gg = this->rhopw->gg[ig]; - double filter_g = std::max(gg / (gg + gg0), this->mixing_gg0_min / amin); - drhog[is * this->rhopw->npw + ig] *= filter_g; + double filter_g = std::max(gg / (gg + gg0), gg0_min); + drhog[is_idx + ig] *= filter_g; } } + + ModuleBase::timer::tick("Charge_Mixing", "Kerker_screen_recip"); return; } void Charge_Mixing::Kerker_screen_real(double* drhor) { - if (this->mixing_gg0 <= 0.0001 || this->mixing_beta <= 0.1) { - return; -} - /// consider a resize for mixing_angle + ModuleBase::TITLE("Charge_Mixing", "Kerker_screen_real"); + + if (this->mixing_gg0 <= 0.0001 || this->mixing_beta <= 0.1) + { + return; + } + + ModuleBase::timer::tick("Charge_Mixing", "Kerker_screen_real"); + + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); + + /// consider a resize for mixing_angle int resize_tmp = 1; - if (PARAM.inp.nspin == 4 && this->mixing_angle > 0) { resize_tmp = 2; -} + if (nspin == 4 && this->mixing_angle > 0) + { + resize_tmp = 2; + } - std::vector> drhog(this->rhopw->npw * PARAM.inp.nspin / resize_tmp); - std::vector drhor_filter(this->rhopw->nrxx * PARAM.inp.nspin / resize_tmp); - for (int is = 0; is < PARAM.inp.nspin / resize_tmp; ++is) + std::vector> drhog(this->rhopw->npw * nspin / resize_tmp); + std::vector drhor_filter(this->rhopw->nrxx * nspin / resize_tmp); + + for (int is = 0; is < nspin / resize_tmp; ++is) { // Note after this process some G which is higher than Gmax will be filtered. // Thus we cannot use Kerker_screen_recip(drhog.data()) directly after it. @@ -82,7 +109,7 @@ void Charge_Mixing::Kerker_screen_real(double* drhor) double gg0 = 0.0; double amin = 0.0; - for (int is = 0; is < PARAM.inp.nspin / resize_tmp; is++) + for (int is = 0; is < nspin / resize_tmp; is++) { if (is >= 1) @@ -92,8 +119,8 @@ void Charge_Mixing::Kerker_screen_real(double* drhor) #ifdef __DEBUG assert(is == 1); /// make sure break works #endif - double is_mag = PARAM.inp.nspin - 1; - if (PARAM.inp.nspin == 4 && this->mixing_angle > 0) { is_mag = 1; + double is_mag = nspin - 1; + if (nspin == 4 && this->mixing_angle > 0) { is_mag = 1; } for (int ig = 0; ig < this->rhopw->npw * is_mag; ig++) { @@ -114,21 +141,25 @@ void Charge_Mixing::Kerker_screen_real(double* drhor) #ifdef _OPENMP #pragma omp parallel for schedule(static, 512) #endif + + const int is_idx = is * this->rhopw->npw; + const double gg0_amin = this->mixing_gg0_min / amin; + for (int ig = 0; ig < this->rhopw->npw; ig++) { double gg = this->rhopw->gg[ig]; // I have not decided how to handle gg=0 part, will be changed in future //if (gg == 0) //{ - // drhog[is * this->rhopw->npw + ig] *= 0; + // drhog[is_idx + ig] *= 0; // continue; //} - double filter_g = std::max(gg / (gg + gg0), this->mixing_gg0_min / amin); - drhog[is * this->rhopw->npw + ig] *= (1 - filter_g); + double filter_g = std::max(gg / (gg + gg0), gg0_amin); + drhog[is_idx + ig] *= (1 - filter_g); } } /// inverse FT - for (int is = 0; is < PARAM.inp.nspin / resize_tmp; ++is) + for (int is = 0; is < nspin / resize_tmp; ++is) { this->rhopw->recip2real(drhog.data() + is * this->rhopw->npw, drhor_filter.data() + is * this->rhopw->nrxx); } @@ -136,8 +167,11 @@ void Charge_Mixing::Kerker_screen_real(double* drhor) #ifdef _OPENMP #pragma omp parallel for schedule(static, 512) #endif - for (int ir = 0; ir < this->rhopw->nrxx * PARAM.inp.nspin / resize_tmp; ir++) + for (int ir = 0; ir < this->rhopw->nrxx * nspin / resize_tmp; ir++) { drhor[ir] -= drhor_filter[ir]; } + + ModuleBase::timer::tick("Charge_Mixing", "Kerker_screen_real"); + return; } From 8d8803b5ee9d25289fa04d9fa7f504738930f90b Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 15:50:20 +0800 Subject: [PATCH 03/18] add timers and remove some PARAM.inp.nspin --- .../module_charge/charge_mixing_rho.cpp | 103 +++++++++++------- 1 file changed, 64 insertions(+), 39 deletions(-) diff --git a/source/source_estate/module_charge/charge_mixing_rho.cpp b/source/source_estate/module_charge/charge_mixing_rho.cpp index 6a965c5da2..38cd679f94 100644 --- a/source/source_estate/module_charge/charge_mixing_rho.cpp +++ b/source/source_estate/module_charge/charge_mixing_rho.cpp @@ -5,6 +5,12 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) { + ModuleBase::TITLE("Charge_Mixing", "mix_rho_recip"); + ModuleBase::timer::tick("Charge_Mixing", "mix_rho_recip"); + + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); + std::complex* rhog_in = nullptr; std::complex* rhog_out = nullptr; // for smooth part @@ -26,7 +32,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) = std::bind(&Charge_Mixing::inner_product_recip_hartree, this, std::placeholders::_1, std::placeholders::_2); // DIIS Mixing Only for smooth part, while high_frequency part is mixed by plain mixing method. - if (PARAM.inp.nspin == 1) + if (nspin == 1) { rhog_in = rhogs_in; rhog_out = rhogs_out; @@ -35,17 +41,17 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) this->mixing->cal_coef(this->rho_mdata, inner_product); this->mixing->mix_data(this->rho_mdata, rhog_out); } - else if (PARAM.inp.nspin == 2) + else if (nspin == 2) { // magnetic density std::complex *rhog_mag = nullptr; std::complex *rhog_mag_save = nullptr; const int npw = this->rhopw->npw; // allocate rhog_mag[is*ngmc] and rhog_mag_save[is*ngmc] - rhog_mag = new std::complex[npw * PARAM.inp.nspin]; - rhog_mag_save = new std::complex[npw * PARAM.inp.nspin]; - ModuleBase::GlobalFunc::ZEROS(rhog_mag, npw * PARAM.inp.nspin); - ModuleBase::GlobalFunc::ZEROS(rhog_mag_save, npw * PARAM.inp.nspin); + rhog_mag = new std::complex[npw * nspin]; + rhog_mag_save = new std::complex[npw * nspin]; + ModuleBase::GlobalFunc::ZEROS(rhog_mag, npw * nspin); + ModuleBase::GlobalFunc::ZEROS(rhog_mag_save, npw * nspin); // get rhog_mag[is*ngmc] and rhog_mag_save[is*ngmc] for (int ig = 0; ig < npw; ig++) { @@ -84,7 +90,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) this->mixing->cal_coef(this->rho_mdata, inner_product); this->mixing->mix_data(this->rho_mdata, rhog_out); // get rhog[is][ngmc] from rhog_mag[is*ngmc] - for (int is = 0; is < PARAM.inp.nspin; is++) + for (int is = 0; is < nspin; is++) { ModuleBase::GlobalFunc::ZEROS(chr->rhog[is], npw); } @@ -106,7 +112,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) } } } - else if (PARAM.inp.nspin == 4 && PARAM.inp.mixing_angle <= 0) + else if (nspin == 4 && PARAM.inp.mixing_angle <= 0) { // normal broyden mixing for {rho, mx, my, mz} rhog_in = rhogs_in; @@ -135,7 +141,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) this->mixing->cal_coef(this->rho_mdata, inner_product); this->mixing->mix_data(this->rho_mdata, rhog_out); } - else if (PARAM.inp.nspin == 4 && PARAM.inp.mixing_angle > 0) + else if (nspin == 4 && PARAM.inp.mixing_angle > 0) { // special broyden mixing for {rho, |m|} proposed by J. Phys. Soc. Jpn. 82 (2013) 114706 // here only consider the case of mixing_angle = 1, which mean only change |m| and keep angle fixed @@ -154,9 +160,13 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) for (int ir = 0; ir < nrxx; ir++) { // |m| for rho - rho_magabs[ir] = std::sqrt(chr->rho[1][ir] * chr->rho[1][ir] + chr->rho[2][ir] * chr->rho[2][ir] + chr->rho[3][ir] * chr->rho[3][ir]); + rho_magabs[ir] = std::sqrt(chr->rho[1][ir] * chr->rho[1][ir] + + chr->rho[2][ir] * chr->rho[2][ir] + + chr->rho[3][ir] * chr->rho[3][ir]); // |m| for rho_save - rho_magabs_save[ir] = std::sqrt(chr->rho_save[1][ir] * chr->rho_save[1][ir] + chr->rho_save[2][ir] * chr->rho_save[2][ir] + chr->rho_save[3][ir] * chr->rho_save[3][ir]); + rho_magabs_save[ir] = std::sqrt(chr->rho_save[1][ir] * chr->rho_save[1][ir] + + chr->rho_save[2][ir] * chr->rho_save[2][ir] + + chr->rho_save[3][ir] * chr->rho_save[3][ir]); } // allocate memory for rhog_magabs and rhog_magabs_save const int npw = this->rhopw->npw; @@ -203,10 +213,14 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) // use new |m| and angle to update {mx, my, mz} for (int ig = 0; ig < npw; ig++) { - chr->rhog[0][ig] = rhog_magabs[ig]; // rhog - double norm = std::sqrt(chr->rho[1][ig] * chr->rho[1][ig] + chr->rho[2][ig] * chr->rho[2][ig] + chr->rho[3][ig] * chr->rho[3][ig]); - if (std::abs(norm) < 1e-10) { continue; -} + chr->rhog[0][ig] = rhog_magabs[ig]; // rhog + double norm = std::sqrt(chr->rho[1][ig] * chr->rho[1][ig] + + chr->rho[2][ig] * chr->rho[2][ig] + + chr->rho[3][ig] * chr->rho[3][ig]); + if (std::abs(norm) < 1e-10) + { + continue; + } double rescale_tmp = rho_magabs[npw + ig] / norm; chr->rho[1][ig] *= rescale_tmp; chr->rho[2][ig] *= rescale_tmp; @@ -222,7 +236,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) if ( PARAM.globalv.double_grid) { // plain mixing for high_frequencies - const int ndimhf = (this->rhodpw->npw - this->rhopw->npw) * PARAM.inp.nspin; + const int ndimhf = (this->rhodpw->npw - this->rhopw->npw) * nspin; this->mixing_highf->plain_mix(rhoghf_out, rhoghf_in, rhoghf_out, ndimhf, nullptr); // combine smooth part and high_frequency part @@ -231,7 +245,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) } // rhog to rho - if (PARAM.inp.nspin == 4 && PARAM.inp.mixing_angle > 0) + if (nspin == 4 && PARAM.inp.mixing_angle > 0) { // only tranfer rhog[0] // do not support double_grid, use rhopw directly @@ -239,7 +253,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) } else { - for (int is = 0; is < PARAM.inp.nspin; is++) + for (int is = 0; is < nspin; is++) { // use rhodpw for double_grid // rhodpw is the same as rhopw for ! PARAM.globalv.double_grid @@ -249,10 +263,10 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) // For kinetic energy density if ((XC_Functional::get_ked_flag()) && mixing_tau) { - std::vector> kin_g(PARAM.inp.nspin * rhodpw->npw); - std::vector> kin_g_save(PARAM.inp.nspin * rhodpw->npw); + std::vector> kin_g(nspin * rhodpw->npw); + std::vector> kin_g_save(nspin * rhodpw->npw); // FFT to get kin_g and kin_g_save - for (int is = 0; is < PARAM.inp.nspin; ++is) + for (int is = 0; is < nspin; ++is) { rhodpw->real2recip(chr->kin_r[is], &kin_g[is * rhodpw->npw]); rhodpw->real2recip(chr->kin_r_save[is], &kin_g_save[is * rhodpw->npw]); @@ -277,7 +291,7 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) if ( PARAM.globalv.double_grid) { // simple mixing for high_frequencies - const int ndimhf = (this->rhodpw->npw - this->rhopw->npw) * PARAM.inp.nspin; + const int ndimhf = (this->rhodpw->npw - this->rhopw->npw) * nspin; this->mixing_highf->plain_mix(taughf_out, taughf_in, taughf_out, ndimhf, nullptr); // combine smooth part and high_frequency part @@ -286,22 +300,28 @@ void Charge_Mixing::mix_rho_recip(Charge* chr) } // kin_g to kin_r - for (int is = 0; is < PARAM.inp.nspin; is++) + for (int is = 0; is < nspin; is++) { rhodpw->recip2real(&kin_g[is * rhodpw->npw], chr->kin_r[is]); } } - + ModuleBase::timer::tick("Charge_Mixing", "mix_rho_recip"); return; } void Charge_Mixing::mix_rho_real(Charge* chr) { + ModuleBase::TITLE("Charge_Mixing", "mix_rho_real"); + ModuleBase::timer::tick("Charge_Mixing", "mix_rho_real"); + + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); + double* rhor_in=nullptr; double* rhor_out=nullptr; - if (PARAM.inp.nspin == 1) + if (nspin == 1) { rhor_in = chr->rho_save[0]; rhor_out = chr->rho[0]; @@ -312,17 +332,17 @@ void Charge_Mixing::mix_rho_real(Charge* chr) this->mixing->cal_coef(this->rho_mdata, inner_product); this->mixing->mix_data(this->rho_mdata, rhor_out); } - else if (PARAM.inp.nspin == 2) + else if (nspin == 2) { // magnetic density double *rho_mag = nullptr; double *rho_mag_save = nullptr; const int nrxx = this->rhopw->nrxx; // allocate rho_mag[is*nnrx] and rho_mag_save[is*nnrx] - rho_mag = new double[nrxx * PARAM.inp.nspin]; - rho_mag_save = new double[nrxx * PARAM.inp.nspin]; - ModuleBase::GlobalFunc::ZEROS(rho_mag, nrxx * PARAM.inp.nspin); - ModuleBase::GlobalFunc::ZEROS(rho_mag_save, nrxx * PARAM.inp.nspin); + rho_mag = new double[nrxx * nspin]; + rho_mag_save = new double[nrxx * nspin]; + ModuleBase::GlobalFunc::ZEROS(rho_mag, nrxx * nspin); + ModuleBase::GlobalFunc::ZEROS(rho_mag_save, nrxx * nspin); // get rho_mag[is*nnrx] and rho_mag_save[is*nnrx] for (int ir = 0; ir < nrxx; ir++) { @@ -362,7 +382,7 @@ void Charge_Mixing::mix_rho_real(Charge* chr) this->mixing->cal_coef(this->rho_mdata, inner_product); this->mixing->mix_data(this->rho_mdata, rhor_out); // get new rho[is][nrxx] from rho_mag[is*nrxx] - for (int is = 0; is < PARAM.inp.nspin; is++) + for (int is = 0; is < nspin; is++) { ModuleBase::GlobalFunc::ZEROS(chr->rho[is], nrxx); //ModuleBase::GlobalFunc::ZEROS(rho_save[is], nrxx); @@ -376,7 +396,7 @@ void Charge_Mixing::mix_rho_real(Charge* chr) delete[] rho_mag; delete[] rho_mag_save; } - else if (PARAM.inp.nspin == 4 && PARAM.inp.mixing_angle <= 0) + else if (nspin == 4 && PARAM.inp.mixing_angle <= 0) { // normal broyden mixing for {rho, mx, my, mz} rhor_in = chr->rho_save[0]; @@ -407,7 +427,7 @@ void Charge_Mixing::mix_rho_real(Charge* chr) this->mixing->cal_coef(this->rho_mdata, inner_product); this->mixing->mix_data(this->rho_mdata, rhor_out); } - else if (PARAM.inp.nspin == 4 && PARAM.inp.mixing_angle > 0) + else if (nspin == 4 && PARAM.inp.mixing_angle > 0) { // special broyden mixing for {rho, |m|} proposed by J. Phys. Soc. Jpn. 82 (2013) 114706 // here only consider the case of mixing_angle = 1, which mean only change |m| and keep angle fixed @@ -494,6 +514,8 @@ void Charge_Mixing::mix_rho_real(Charge* chr) this->mixing->mix_data(this->tau_mdata, taur_out); } + ModuleBase::timer::tick("Charge_Mixing", "mix_rho_real"); + return; } @@ -502,10 +524,13 @@ void Charge_Mixing::mix_rho(Charge* chr) ModuleBase::TITLE("Charge_Mixing", "mix_rho"); ModuleBase::timer::tick("Charge_Mixing", "mix_rho"); + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); + // the charge before mixing. const int nrxx = chr->rhopw->nrxx; - std::vector rho123(PARAM.inp.nspin * nrxx); - for (int is = 0; is < PARAM.inp.nspin; ++is) + std::vector rho123(nspin * nrxx); + for (int is = 0; is < nspin; ++is) { if (is == 0 || is == 3 || !PARAM.globalv.domag_z) { @@ -522,8 +547,8 @@ void Charge_Mixing::mix_rho(Charge* chr) std::vector kin_r123; if ((XC_Functional::get_ked_flag()) && mixing_tau) { - kin_r123.resize(PARAM.inp.nspin * nrxx); - for (int is = 0; is < PARAM.inp.nspin; ++is) + kin_r123.resize(nspin * nrxx); + for (int is = 0; is < nspin; ++is) { double* kin_r123_is = kin_r123.data() + is * nrxx; #ifdef _OPENMP @@ -548,7 +573,7 @@ void Charge_Mixing::mix_rho(Charge* chr) // mohan add 2012-06-05 // rho_save is the charge before mixing - for (int is = 0; is < PARAM.inp.nspin; ++is) + for (int is = 0; is < nspin; ++is) { if (is == 0 || is == 3 || !PARAM.globalv.domag_z) { @@ -565,7 +590,7 @@ void Charge_Mixing::mix_rho(Charge* chr) if ((XC_Functional::get_ked_flag()) && mixing_tau) { - for (int is = 0; is < PARAM.inp.nspin; ++is) + for (int is = 0; is < nspin; ++is) { double* kin_r123_is = kin_r123.data() + is * nrxx; #ifdef _OPENMP From 91b9ceb20c220b092a075e361aea585b7b751f61 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 15:59:41 +0800 Subject: [PATCH 04/18] fix problems --- .../module_charge/charge_mixing_preconditioner.cpp | 14 +++++++------- .../module_charge/charge_mixing_residual.cpp | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/source/source_estate/module_charge/charge_mixing_preconditioner.cpp b/source/source_estate/module_charge/charge_mixing_preconditioner.cpp index 165b967ec8..00b5d42038 100644 --- a/source/source_estate/module_charge/charge_mixing_preconditioner.cpp +++ b/source/source_estate/module_charge/charge_mixing_preconditioner.cpp @@ -57,15 +57,16 @@ void Charge_Mixing::Kerker_screen_recip(std::complex* drhog) } gg0 = std::pow(fac * 0.529177 / *this->tpiba, 2); + + const double gg0_amin = this->mixing_gg0_min / amin; + #ifdef _OPENMP #pragma omp parallel for schedule(static, 512) #endif - const double gg0_amin = this->mixing_gg0_min / amin; - for (int ig = 0; ig < this->rhopw->npw; ++ig) { double gg = this->rhopw->gg[ig]; - double filter_g = std::max(gg / (gg + gg0), gg0_min); + double filter_g = std::max(gg / (gg + gg0), gg0_amin); drhog[is_idx + ig] *= filter_g; } } @@ -138,13 +139,12 @@ void Charge_Mixing::Kerker_screen_real(double* drhor) } gg0 = std::pow(fac * 0.529177 / *this->tpiba, 2); -#ifdef _OPENMP -#pragma omp parallel for schedule(static, 512) -#endif const int is_idx = is * this->rhopw->npw; const double gg0_amin = this->mixing_gg0_min / amin; - +#ifdef _OPENMP +#pragma omp parallel for schedule(static, 512) +#endif for (int ig = 0; ig < this->rhopw->npw; ig++) { double gg = this->rhopw->gg[ig]; diff --git a/source/source_estate/module_charge/charge_mixing_residual.cpp b/source/source_estate/module_charge/charge_mixing_residual.cpp index 8862ede35e..bbc7c7b6be 100644 --- a/source/source_estate/module_charge/charge_mixing_residual.cpp +++ b/source/source_estate/module_charge/charge_mixing_residual.cpp @@ -31,10 +31,9 @@ double Charge_Mixing::get_drho(Charge* chr, const double nelec) #endif for (int is = 0; is < nspin; ++is) { - const int is_idx = is * this->rhopw->npw; for (int ig = 0; ig < this->rhopw->npw; ig++) { - drhog[is_idx + ig] = chr->rhog[is][ig] - chr->rhog_save[is][ig]; + drhog[is * this->rhopw->npw + ig] = chr->rhog[is][ig] - chr->rhog_save[is][ig]; } } From a7f629459de6162a1252278c9dc9f74a48eff8d1 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 16:09:03 +0800 Subject: [PATCH 05/18] add timers --- .../source_cell/module_symmetry/symm_rho.cpp | 4 ++ .../module_charge/symmetry_rho.cpp | 57 ++++++++++++------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/source/source_cell/module_symmetry/symm_rho.cpp b/source/source_cell/module_symmetry/symm_rho.cpp index 754279ffbc..28ae30d5b4 100644 --- a/source/source_cell/module_symmetry/symm_rho.cpp +++ b/source/source_cell/module_symmetry/symm_rho.cpp @@ -9,6 +9,10 @@ void Symmetry::rho_symmetry( double *rho, { ModuleBase::timer::tick("Symmetry","rho_symmetry"); + assert(nr1>0); + assert(nr2>0); + assert(nr3>0); + // allocate flag for each FFT grid. bool* symflag = new bool[nr1 * nr2 * nr3]; for (int i=0; ireal2recip(CHR.rho[spin_now], CHR.rhog[spin_now]); - psymmg(CHR.rhog[spin_now], rho_basis, symm); // need to modify - rho_basis->recip2real(CHR.rhog[spin_now], CHR.rho[spin_now]); - if (XC_Functional::get_ked_flag() || CHR.cal_elf) + rho_basis->real2recip(chr.rho[spin_now], chr.rhog[spin_now]); + + psymmg(chr.rhog[spin_now], rho_basis, symm); // need to modify + + rho_basis->recip2real(chr.rhog[spin_now], chr.rho[spin_now]); + + if (XC_Functional::get_ked_flag() || chr.cal_elf) { // Use std::vector to manage kin_g instead of raw pointer - std::vector> kin_g(CHR.ngmc); - rho_basis->real2recip(CHR.kin_r[spin_now], kin_g.data()); + std::vector> kin_g(chr.ngmc); + rho_basis->real2recip(chr.kin_r[spin_now], kin_g.data()); psymmg(kin_g.data(), rho_basis, symm); - rho_basis->recip2real(kin_g.data(), CHR.kin_r[spin_now]); - } + rho_basis->recip2real(kin_g.data(), chr.kin_r[spin_now]); } + + ModuleBase::timer::tick("Symmetry_rho","begin"); return; } @@ -59,6 +67,10 @@ void Symmetry_rho::begin(const int& spin_now, { return; } + + ModuleBase::TITLE("Symmetry_rho", "begin"); + ModuleBase::timer::tick("Symmetry_rho","begin"); + // both parallel and serial // if(symm.nrot==symm.nrotk) //pure point-group, do rho_symm in real space // { @@ -81,6 +93,8 @@ void Symmetry_rho::begin(const int& spin_now, rho_basis->recip2real(kin_g.data(), kin_r[spin_now]); } } + + ModuleBase::timer::tick("Symmetry_rho","begin"); return; } @@ -89,8 +103,11 @@ void Symmetry_rho::psymm(double* rho_part, Parallel_Grid& Pgrid, ModuleSymmetry::Symmetry& symm) const { + ModuleBase::TITLE("Symmetry_rho", "psymm"); + ModuleBase::timer::tick("Symmetry_rho","psymm"); + #ifdef __MPI - // (1) reduce all rho from the first pool. + // reduce all rho from the first pool. std::vector rhotot; if (GlobalV::MY_RANK == 0) { @@ -99,7 +116,6 @@ void Symmetry_rho::psymm(double* rho_part, } Pgrid.reduce(rhotot.data(), rho_part, false); - // (2) if (GlobalV::MY_RANK == 0) { symm.rho_symmetry(rhotot.data(), rho_basis->nx, rho_basis->ny, rho_basis->nz); @@ -126,8 +142,9 @@ void Symmetry_rho::psymm(double* rho_part, #ifdef __MPI } - // (3) -Pgrid.bcast(rhotot.data(), rho_part, GlobalV::MY_RANK); + Pgrid.bcast(rhotot.data(), rho_part, GlobalV::MY_RANK); #endif + + ModuleBase::timer::tick("Symmetry_rho","psymm"); return; -} \ No newline at end of file +} From 9f957336051341586992c1251afe3b2a846ffb79 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 16:34:51 +0800 Subject: [PATCH 06/18] small fix --- source/source_esolver/esolver_of_tool.cpp | 2 +- .../module_pot/potential_new.cpp | 56 +++++++++++-------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/source/source_esolver/esolver_of_tool.cpp b/source/source_esolver/esolver_of_tool.cpp index a71b128cf3..50c0ff6561 100644 --- a/source/source_esolver/esolver_of_tool.cpp +++ b/source/source_esolver/esolver_of_tool.cpp @@ -220,7 +220,7 @@ double ESolver_OF::cal_mu(double* pphi, double* pdEdphi, double nelec) * @brief Rotate and renormalize the direction |d>, * make it orthogonal to phi ( = 0), and = nelec */ -void ESolver_OF::adjust_direction() +void ESolver_OF::adjust_direction(void) { // filter the high frequency term in direction if of_full_pw = false if (!PARAM.inp.of_full_pw) diff --git a/source/source_estate/module_pot/potential_new.cpp b/source/source_estate/module_pot/potential_new.cpp index 3d958bc87e..b1968339aa 100644 --- a/source/source_estate/module_pot/potential_new.cpp +++ b/source/source_estate/module_pot/potential_new.cpp @@ -23,7 +23,8 @@ Potential::Potential(const ModulePW::PW_Basis* rho_basis_in, double* etxc_in, double* vtxc_in, VSep* vsep_cell_in) - : ucell_(ucell_in), vloc_(vloc_in), structure_factors_(structure_factors_in), solvent_(solvent_in), vsep_cell(vsep_cell_in), etxc_(etxc_in), + : ucell_(ucell_in), vloc_(vloc_in), structure_factors_(structure_factors_in), + solvent_(solvent_in), vsep_cell(vsep_cell_in), etxc_(etxc_in), vtxc_(vtxc_in) { this->rho_basis_ = rho_basis_in; @@ -81,7 +82,6 @@ void Potential::pot_register(const std::vector& components_list) { PotBase* tmp = this->get_pot_type(comp); this->components.push_back(tmp); - // GlobalV::ofs_running << "Successful completion of Potential's registration : " << comp << std::endl; } // after register, reset fixed_done to false @@ -93,8 +93,13 @@ void Potential::pot_register(const std::vector& components_list) void Potential::allocate() { ModuleBase::TITLE("Potential", "allocate"); - int nrxx = this->rho_basis_->nrxx; - int nrxx_smooth = this->rho_basis_smooth_->nrxx; + + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); + + const int nrxx = this->rho_basis_->nrxx; + const int nrxx_smooth = this->rho_basis_smooth_->nrxx; + if (nrxx == 0) { return; @@ -107,39 +112,39 @@ void Potential::allocate() this->v_effective_fixed.resize(nrxx); ModuleBase::Memory::record("Pot::veff_fix", sizeof(double) * nrxx); - this->v_effective.create(PARAM.inp.nspin, nrxx); - ModuleBase::Memory::record("Pot::veff", sizeof(double) * PARAM.inp.nspin * nrxx); + this->v_effective.create(nspin, nrxx); + ModuleBase::Memory::record("Pot::veff", sizeof(double) * nspin * nrxx); - this->veff_smooth.create(PARAM.inp.nspin, nrxx_smooth); - ModuleBase::Memory::record("Pot::veff_smooth", sizeof(double) * PARAM.inp.nspin * nrxx_smooth); + this->veff_smooth.create(nspin, nrxx_smooth); + ModuleBase::Memory::record("Pot::veff_smooth", sizeof(double) * nspin * nrxx_smooth); if (XC_Functional::get_ked_flag()) { - this->vofk_effective.create(PARAM.inp.nspin, nrxx); - ModuleBase::Memory::record("Pot::vofk", sizeof(double) * PARAM.inp.nspin * nrxx); + this->vofk_effective.create(nspin, nrxx); + ModuleBase::Memory::record("Pot::vofk", sizeof(double) * nspin * nrxx); - this->vofk_smooth.create(PARAM.inp.nspin, nrxx_smooth); - ModuleBase::Memory::record("Pot::vofk_smooth", sizeof(double) * PARAM.inp.nspin * nrxx_smooth); + this->vofk_smooth.create(nspin, nrxx_smooth); + ModuleBase::Memory::record("Pot::vofk_smooth", sizeof(double) * nspin * nrxx_smooth); } if (use_gpu_) { if (PARAM.globalv.has_float_data) { - resmem_sd_op()(s_veff_smooth, PARAM.inp.nspin * nrxx_smooth); - resmem_sd_op()(s_vofk_smooth, PARAM.inp.nspin * nrxx_smooth); + resmem_sd_op()(s_veff_smooth, nspin * nrxx_smooth); + resmem_sd_op()(s_vofk_smooth, nspin * nrxx_smooth); } if (PARAM.globalv.has_double_data) { - resmem_dd_op()(d_veff_smooth, PARAM.inp.nspin * nrxx_smooth); - resmem_dd_op()(d_vofk_smooth, PARAM.inp.nspin * nrxx_smooth); + resmem_dd_op()(d_veff_smooth, nspin * nrxx_smooth); + resmem_dd_op()(d_vofk_smooth, nspin * nrxx_smooth); } } else { if (PARAM.globalv.has_float_data) { - resmem_sh_op()(s_veff_smooth, PARAM.inp.nspin * nrxx_smooth, "POT::sveff_smooth"); - resmem_sh_op()(s_vofk_smooth, PARAM.inp.nspin * nrxx_smooth, "POT::svofk_smooth"); + resmem_sh_op()(s_veff_smooth, nspin * nrxx_smooth, "POT::sveff_smooth"); + resmem_sh_op()(s_vofk_smooth, nspin * nrxx_smooth, "POT::svofk_smooth"); } if (PARAM.globalv.has_double_data) { @@ -273,20 +278,23 @@ void Potential::get_vnew(const Charge* chg, ModuleBase::matrix& vnew) return; } -void Potential::interpolate_vrs() +void Potential::interpolate_vrs(void) { ModuleBase::TITLE("Potential", "interpolate_vrs"); ModuleBase::timer::tick("Potential", "interpolate_vrs"); - if ( PARAM.globalv.double_grid) + const int nspin = PARAM.inp.nspin; + assert(nspin==1 || nspin==2 || nspin==4); + + if (PARAM.globalv.double_grid) { if (rho_basis_->gamma_only != rho_basis_smooth_->gamma_only) { ModuleBase::WARNING_QUIT("Potential::interpolate_vrs", "gamma_only is not consistent"); } - ModuleBase::ComplexMatrix vrs(PARAM.inp.nspin, rho_basis_->npw); - for (int is = 0; is < PARAM.inp.nspin; is++) + ModuleBase::ComplexMatrix vrs(nspin, rho_basis_->npw); + for (int is = 0; is < nspin; is++) { rho_basis_->real2recip(&v_effective(is, 0), &vrs(is, 0)); rho_basis_smooth_->recip2real(&vrs(is, 0), &veff_smooth(is, 0)); @@ -294,8 +302,8 @@ void Potential::interpolate_vrs() if (XC_Functional::get_ked_flag()) { - ModuleBase::ComplexMatrix vrs_ofk(PARAM.inp.nspin, rho_basis_->npw); - for (int is = 0; is < PARAM.inp.nspin; is++) + ModuleBase::ComplexMatrix vrs_ofk(nspin, rho_basis_->npw); + for (int is = 0; is < nspin; is++) { rho_basis_->real2recip(&vofk_effective(is, 0), &vrs_ofk(is, 0)); rho_basis_smooth_->recip2real(&vrs_ofk(is, 0), &vofk_smooth(is, 0)); From f7f156dffb062f6a74ccde1f7f7def25ae753ce0 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Fri, 26 Sep 2025 16:41:40 +0800 Subject: [PATCH 07/18] fix a potential memory leak --- source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp b/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp index 95d076040e..5eedae9259 100644 --- a/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp +++ b/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp @@ -180,6 +180,7 @@ class PW_BASIS_K_GPU_TEST : public ::testing::Test } void TearDown() override { + delete[] kvec_d; // mohan add 20250926 delete[] h_rhog; delete[] h_rhogout; delete[] h_rhor; From 94e578026ad485e2fb3251d4873295702772a6f3 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 10:17:55 +0800 Subject: [PATCH 08/18] fix bug --- source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp b/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp index 5eedae9259..9abaefab44 100644 --- a/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp +++ b/source/source_basis/module_pw/test_gpu/pw_basis_C2C.cpp @@ -44,9 +44,6 @@ class PW_BASIS_K_GPU_TEST : public ::testing::Test int distribution_type = 1; bool xprime = false; const int nks = 1; - ModuleBase::Vector3* kvec_d; - kvec_d = new ModuleBase::Vector3[nks]; - kvec_d[0].set(0, 0, 0); // init const int mypool = 0; const int key = 1; @@ -180,7 +177,6 @@ class PW_BASIS_K_GPU_TEST : public ::testing::Test } void TearDown() override { - delete[] kvec_d; // mohan add 20250926 delete[] h_rhog; delete[] h_rhogout; delete[] h_rhor; From 8616ac3934b3a62ca406bbd0f6fa400227662e8d Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 13:52:37 +0800 Subject: [PATCH 09/18] add ctrl_output_pw files but cannot run now --- source/source_esolver/esolver_ks_pw.cpp | 305 +------------------- source/source_io/ctrl_output_pw.cpp | 356 ++++++++++++++++++++++++ source/source_io/ctrl_output_pw.h | 19 ++ 3 files changed, 380 insertions(+), 300 deletions(-) create mode 100644 source/source_io/ctrl_output_pw.cpp create mode 100644 source/source_io/ctrl_output_pw.h diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 2c342bfa13..ac522b495c 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -634,54 +634,6 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } - //---------------------------------------------------------- - // 3) Print out electronic wavefunctions in pw basis - // we only print information every few ionic steps - //---------------------------------------------------------- - - // if istep_in = -1, istep will not appear in file name - // if iter_in = -1, iter will not appear in file name - int istep_in = -1; - int iter_in = -1; - bool out_wfc_flag = false; - if (PARAM.inp.out_freq_ion>0) // default value of out_freq_ion is 0 - { - if (istep % PARAM.inp.out_freq_ion == 0) - { - if(iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) - { - istep_in = istep; - iter_in = iter; - out_wfc_flag = true; - } - } - } - else if(iter == PARAM.inp.scf_nmax || conv_esolver) - { - out_wfc_flag = true; - } - - - if (out_wfc_flag) - { - ModuleIO::write_wfc_pw(istep_in, iter_in, - GlobalV::KPAR, - GlobalV::MY_POOL, - GlobalV::MY_RANK, - PARAM.inp.nbands, - PARAM.inp.nspin, - PARAM.globalv.npol, - GlobalV::RANK_IN_POOL, - GlobalV::NPROC_IN_POOL, - PARAM.inp.out_wfc_pw, - PARAM.inp.ecutwfc, - PARAM.globalv.global_out_dir, - this->psi[0], - this->kv, - this->pw_wfc, - GlobalV::ofs_running); - } - //---------------------------------------------------------- // 4) check if oscillate for delta_spin method //---------------------------------------------------------- @@ -699,6 +651,8 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } } + + ctrl_iter_pw(); } template @@ -733,144 +687,7 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->psi[0].size()); } - //---------------------------------------------------------- - //! 4) Compute density of states (DOS) - //---------------------------------------------------------- - if (PARAM.inp.out_dos) - { - bool out_dos_tmp = false; - - int istep_in = -1; - - // default value of out_freq_ion is 0 - if(PARAM.inp.out_freq_ion==0) - { - out_dos_tmp = true; - } - else if (PARAM.inp.out_freq_ion>0) - { - if (istep % PARAM.inp.out_freq_ion == 0) - { - out_dos_tmp = true; - istep_in=istep; - } - else - { - out_dos_tmp = false; - } - } - else - { - out_dos_tmp = false; - } - - // the above is only valid for KSDFT, not SDFT - // this part needs update in the near future - if (PARAM.inp.esolver_type == "sdft") - { - out_dos_tmp = false; - } - - if(out_dos_tmp) - { - ModuleIO::write_dos_pw(ucell, - this->pelec->ekb, - this->pelec->wg, - this->kv, - PARAM.inp.nbands, - istep_in, - this->pelec->eferm, - PARAM.inp.dos_edelta_ev, - PARAM.inp.dos_scale, - PARAM.inp.dos_sigma, - GlobalV::ofs_running); - } - } - - //------------------------------------------------------------------ - // 5) calculate band-decomposed (partial) charge density in pw basis - //------------------------------------------------------------------ - if (PARAM.inp.out_pchg.size() > 0) - { - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); - - ModuleIO::get_pchg_pw(PARAM.inp.out_pchg, - this->kspw_psi->get_nbands(), - PARAM.inp.nspin, - this->pw_rhod->nxyz, - this->chr.ngmc, - &ucell, - this->__kspw_psi, - this->pw_rhod, - this->pw_wfc, - this->ctx, - this->Pgrid, - PARAM.globalv.global_out_dir, - PARAM.inp.if_separate_k, - this->kv, - GlobalV::KPAR, - GlobalV::MY_POOL, - &this->chr); - } - - //------------------------------------------------------------------ - //! 6) calculate Wannier functions in pw basis - //------------------------------------------------------------------ - if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90) - { - std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation"); - toWannier90_PW wan(PARAM.inp.out_wannier_mmn, - PARAM.inp.out_wannier_amn, - PARAM.inp.out_wannier_unk, - PARAM.inp.out_wannier_eig, - PARAM.inp.out_wannier_wvfn_formatted, - PARAM.inp.nnkpfile, - PARAM.inp.wannier_spin); - wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, this->pelec->ekb, this->pw_wfc, this->pw_big, this->kv, this->psi); - std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); - } - - //------------------------------------------------------------------ - //! 7) calculate Berry phase polarization in pw basis - //------------------------------------------------------------------ - if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1) - { - std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); - berryphase bp; - bp.Macroscopic_polarization(ucell, this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, this->kv); - std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); - } - - //------------------------------------------------------------------ - // 8) write spin constrian results in pw basis - // spin constrain calculations, write atomic magnetization and magnetic force. - //------------------------------------------------------------------ - if (PARAM.inp.sc_mag_switch) - { - spinconstrain::SpinConstrain>& sc - = spinconstrain::SpinConstrain>::getScInstance(); - sc.cal_mi_pw(); - sc.print_Mag_Force(GlobalV::ofs_running); - } - - //------------------------------------------------------------------ - // 9) write onsite occupations for charge and magnetizations - //------------------------------------------------------------------ - if (PARAM.inp.onsite_radius > 0) - { // float type has not been implemented - auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(this->kspw_psi), - this->pelec->wg); - } + ModuleIO::ctrl_scf_pw(); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -954,121 +771,9 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) //---------------------------------------------------------- ESolver_KS::after_all_runners(ucell); - //---------------------------------------------------------- - //! 2) Compute LDOS - //---------------------------------------------------------- - if (PARAM.inp.out_ldos[0]) - { - ModuleIO::cal_ldos_pw(reinterpret_cast>*>(this->pelec), - this->psi[0], - this->Pgrid, - ucell); - } - - //---------------------------------------------------------- - //! 3) Calculate the spillage value, - //! which are used to generate numerical atomic orbitals - //---------------------------------------------------------- - if (PARAM.inp.basis_type == "pw" && PARAM.inp.out_spillage) - { - // ! Print out overlap matrices - if (PARAM.inp.out_spillage <= 2) - { - for (int i = 0; i < PARAM.inp.bessel_nao_rcuts.size(); i++) - { - if (GlobalV::MY_RANK == 0) - { - std::cout << "update value: bessel_nao_rcut <- " << std::fixed << PARAM.inp.bessel_nao_rcuts[i] - << " a.u." << std::endl; - } - Numerical_Basis numerical_basis; - numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc, ucell, i); - } - ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); - } - } - - //---------------------------------------------------------- - //! 4) Print out electronic wave functions in real space - //---------------------------------------------------------- - if (PARAM.inp.out_wfc_norm.size() > 0 || PARAM.inp.out_wfc_re_im.size() > 0) - { - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") - { - delete reinterpret_cast, Device>*>(this->__kspw_psi); - } - - // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); - - ModuleIO::get_wf_pw(PARAM.inp.out_wfc_norm, - PARAM.inp.out_wfc_re_im, - this->kspw_psi->get_nbands(), - PARAM.inp.nspin, - this->pw_rhod->nxyz, - &ucell, - this->__kspw_psi, - this->pw_wfc, - this->ctx, - this->Pgrid, - PARAM.globalv.global_out_dir, - this->kv, - GlobalV::KPAR, - GlobalV::MY_POOL); - } - - //---------------------------------------------------------- - //! 5) Use Kubo-Greenwood method to compute conductivities - //---------------------------------------------------------- - if (PARAM.inp.cal_cond) - { - EleCond elec_cond(&ucell, &this->kv, this->pelec, this->pw_wfc, this->kspw_psi, &this->ppcell); - elec_cond.KG(PARAM.inp.cond_smear, - PARAM.inp.cond_fwhm, - PARAM.inp.cond_wcut, - PARAM.inp.cond_dw, - PARAM.inp.cond_dt, - PARAM.inp.cond_nonlocal, - this->pelec->wg); - } + ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_big, this->pw_rhod, + this->chr, this->solvent, this->Pgrid, istep); -#ifdef __MLALGO - //---------------------------------------------------------- - //! 7) generate training data for ML-KEDF - //---------------------------------------------------------- - if (PARAM.inp.of_ml_gene_data == 1) - { - this->pelec->pot->update_from_charge(&this->chr, &ucell); - - ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc; - write_mlkedf_desc.cal_tool->set_para(this->chr.nrxx, - PARAM.inp.nelec, - PARAM.inp.of_tf_weight, - PARAM.inp.of_vw_weight, - PARAM.inp.of_ml_chi_p, - PARAM.inp.of_ml_chi_q, - PARAM.inp.of_ml_chi_xi, - PARAM.inp.of_ml_chi_pnl, - PARAM.inp.of_ml_chi_qnl, - PARAM.inp.of_ml_nkernel, - PARAM.inp.of_ml_kernel, - PARAM.inp.of_ml_kernel_scaling, - PARAM.inp.of_ml_yukawa_alpha, - PARAM.inp.of_ml_kernel_file, - ucell.omega, - this->pw_rho); - - write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - this->kspw_psi, - this->pelec, - this->pw_wfc, - this->pw_rho, - ucell, - this->pelec->pot->get_effective_v(0)); - } -#endif } template class ESolver_KS_PW, base_device::DEVICE_CPU>; diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp new file mode 100644 index 0000000000..6563d3e6c8 --- /dev/null +++ b/source/source_io/ctrl_output_pw.cpp @@ -0,0 +1,356 @@ +/* +#include "source_io/ctrl_output_fp.h" // use ctrl_output_fp() +#include "source_estate/module_charge/symmetry_rho.h" // use Symmetry_rho +#include "source_io/write_elecstat_pot.h" // use write_elecstat_pot +#include "source_io/write_elf.h" +#include "cube_io.h" // use write_vdata_palgrid +#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional + +#ifdef USE_LIBXC +#include "source_io/write_libxc_r.h" +#endif +*/ + +namespace ModuleIO +{ + +void ctrl_iter_pw() +{ + ModuleBase::TITLE("ModuleIO", "ctrl_iter_pw"); + ModuleBase::timer::tick("ModuleIO", "ctrl_iter_pw"); + //---------------------------------------------------------- + // 3) Print out electronic wavefunctions in pw basis + // we only print information every few ionic steps + //---------------------------------------------------------- + + // if istep_in = -1, istep will not appear in file name + // if iter_in = -1, iter will not appear in file name + int istep_in = -1; + int iter_in = -1; + bool out_wfc_flag = false; + if (PARAM.inp.out_freq_ion>0) // default value of out_freq_ion is 0 + { + if (istep % PARAM.inp.out_freq_ion == 0) + { + if(iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) + { + istep_in = istep; + iter_in = iter; + out_wfc_flag = true; + } + } + } + else if(iter == PARAM.inp.scf_nmax || conv_esolver) + { + out_wfc_flag = true; + } + + if (out_wfc_flag) + { + ModuleIO::write_wfc_pw(istep_in, iter_in, + GlobalV::KPAR, + GlobalV::MY_POOL, + GlobalV::MY_RANK, + PARAM.inp.nbands, + PARAM.inp.nspin, + PARAM.globalv.npol, + GlobalV::RANK_IN_POOL, + GlobalV::NPROC_IN_POOL, + PARAM.inp.out_wfc_pw, + PARAM.inp.ecutwfc, + PARAM.globalv.global_out_dir, + this->psi[0], + this->kv, + this->pw_wfc, + GlobalV::ofs_running); + } + + ModuleBase::timer::tick("ModuleIO", "ctrl_iter_pw"); + return; +} + + +void ctrl_scf_pw() +{ + ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); + ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + + //---------------------------------------------------------- + //! 4) Compute density of states (DOS) + //---------------------------------------------------------- + if (PARAM.inp.out_dos) + { + bool out_dos_tmp = false; + + int istep_in = -1; + + // default value of out_freq_ion is 0 + if(PARAM.inp.out_freq_ion==0) + { + out_dos_tmp = true; + } + else if (PARAM.inp.out_freq_ion>0) + { + if (istep % PARAM.inp.out_freq_ion == 0) + { + out_dos_tmp = true; + istep_in=istep; + } + else + { + out_dos_tmp = false; + } + } + else + { + out_dos_tmp = false; + } + + // the above is only valid for KSDFT, not SDFT + // this part needs update in the near future + if (PARAM.inp.esolver_type == "sdft") + { + out_dos_tmp = false; + } + + if(out_dos_tmp) + { + ModuleIO::write_dos_pw(ucell, + this->pelec->ekb, + this->pelec->wg, + this->kv, + PARAM.inp.nbands, + istep_in, + this->pelec->eferm, + PARAM.inp.dos_edelta_ev, + PARAM.inp.dos_scale, + PARAM.inp.dos_sigma, + GlobalV::ofs_running); + } + } + + + //------------------------------------------------------------------ + // 5) calculate band-decomposed (partial) charge density in pw basis + //------------------------------------------------------------------ + if (PARAM.inp.out_pchg.size() > 0) + { + if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") + { + delete reinterpret_cast, Device>*>(this->__kspw_psi); + } + + // Refresh __kspw_psi + this->__kspw_psi = PARAM.inp.precision == "single" + ? new psi::Psi, Device>(this->kspw_psi[0]) + : reinterpret_cast, Device>*>(this->kspw_psi); + + ModuleIO::get_pchg_pw(PARAM.inp.out_pchg, + this->kspw_psi->get_nbands(), + PARAM.inp.nspin, + this->pw_rhod->nxyz, + this->chr.ngmc, + &ucell, + this->__kspw_psi, + this->pw_rhod, + this->pw_wfc, + this->ctx, + this->Pgrid, + PARAM.globalv.global_out_dir, + PARAM.inp.if_separate_k, + this->kv, + GlobalV::KPAR, + GlobalV::MY_POOL, + &this->chr); + } + + + //------------------------------------------------------------------ + //! 6) calculate Wannier functions in pw basis + //------------------------------------------------------------------ + if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90) + { + std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation"); + toWannier90_PW wan(PARAM.inp.out_wannier_mmn, + PARAM.inp.out_wannier_amn, + PARAM.inp.out_wannier_unk, + PARAM.inp.out_wannier_eig, + PARAM.inp.out_wannier_wvfn_formatted, + PARAM.inp.nnkpfile, + PARAM.inp.wannier_spin); + wan.set_tpiba_omega(ucell.tpiba, ucell.omega); + wan.calculate(ucell, this->pelec->ekb, this->pw_wfc, this->pw_big, this->kv, this->psi); + std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); + } + + + //------------------------------------------------------------------ + //! 7) calculate Berry phase polarization in pw basis + //------------------------------------------------------------------ + if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1) + { + std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); + berryphase bp; + bp.Macroscopic_polarization(ucell, this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, this->kv); + std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); + } + + //------------------------------------------------------------------ + // 8) write spin constrian results in pw basis + // spin constrain calculations, write atomic magnetization and magnetic force. + //------------------------------------------------------------------ + if (PARAM.inp.sc_mag_switch) + { + spinconstrain::SpinConstrain>& sc + = spinconstrain::SpinConstrain>::getScInstance(); + sc.cal_mi_pw(); + sc.print_Mag_Force(GlobalV::ofs_running); + } + + //------------------------------------------------------------------ + // 9) write onsite occupations for charge and magnetizations + //------------------------------------------------------------------ + if (PARAM.inp.onsite_radius > 0) + { // float type has not been implemented + auto* onsite_p = projectors::OnsiteProjector::get_instance(); + onsite_p->cal_occupations(reinterpret_cast, Device>*>(this->kspw_psi), + this->pelec->wg); + } + + ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); + return; +} + + +void ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_Big* pw_big, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + surchem &solvent, + Parallel_Grid ¶_grid, + const int istep) +{ + ModuleBase::TITLE("ModuleIO", "ctrl_runner_pw"); + ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); + + //---------------------------------------------------------- + //! 1) Compute LDOS + //---------------------------------------------------------- + if (PARAM.inp.out_ldos[0]) + { + ModuleIO::cal_ldos_pw(reinterpret_cast>*>(this->pelec), + this->psi[0], + this->Pgrid, + ucell); + } + + //---------------------------------------------------------- + //! 2) Calculate the spillage value, + //! which are used to generate numerical atomic orbitals + //---------------------------------------------------------- + if (PARAM.inp.basis_type == "pw" && PARAM.inp.out_spillage) + { + // ! Print out overlap matrices + if (PARAM.inp.out_spillage <= 2) + { + for (int i = 0; i < PARAM.inp.bessel_nao_rcuts.size(); i++) + { + if (GlobalV::MY_RANK == 0) + { + std::cout << "update value: bessel_nao_rcut <- " << std::fixed << PARAM.inp.bessel_nao_rcuts[i] + << " a.u." << std::endl; + } + Numerical_Basis numerical_basis; + numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc, ucell, i); + } + ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); + } + } + + //---------------------------------------------------------- + //! 3) Print out electronic wave functions in real space + //---------------------------------------------------------- + if (PARAM.inp.out_wfc_norm.size() > 0 || PARAM.inp.out_wfc_re_im.size() > 0) + { + if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") + { + delete reinterpret_cast, Device>*>(this->__kspw_psi); + } + + // Refresh __kspw_psi + this->__kspw_psi = PARAM.inp.precision == "single" + ? new psi::Psi, Device>(this->kspw_psi[0]) + : reinterpret_cast, Device>*>(this->kspw_psi); + + ModuleIO::get_wf_pw(PARAM.inp.out_wfc_norm, + PARAM.inp.out_wfc_re_im, + this->kspw_psi->get_nbands(), + PARAM.inp.nspin, + this->pw_rhod->nxyz, + &ucell, + this->__kspw_psi, + this->pw_wfc, + this->ctx, + this->Pgrid, + PARAM.globalv.global_out_dir, + this->kv, + GlobalV::KPAR, + GlobalV::MY_POOL); + } + + //---------------------------------------------------------- + //! 4) Use Kubo-Greenwood method to compute conductivities + //---------------------------------------------------------- + if (PARAM.inp.cal_cond) + { + EleCond elec_cond(&ucell, &this->kv, this->pelec, this->pw_wfc, this->kspw_psi, &this->ppcell); + elec_cond.KG(PARAM.inp.cond_smear, + PARAM.inp.cond_fwhm, + PARAM.inp.cond_wcut, + PARAM.inp.cond_dw, + PARAM.inp.cond_dt, + PARAM.inp.cond_nonlocal, + this->pelec->wg); + } + +#ifdef __MLALGO + //---------------------------------------------------------- + //! 7) generate training data for ML-KEDF + //---------------------------------------------------------- + if (PARAM.inp.of_ml_gene_data == 1) + { + this->pelec->pot->update_from_charge(&this->chr, &ucell); + + ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc; + write_mlkedf_desc.cal_tool->set_para(this->chr.nrxx, + PARAM.inp.nelec, + PARAM.inp.of_tf_weight, + PARAM.inp.of_vw_weight, + PARAM.inp.of_ml_chi_p, + PARAM.inp.of_ml_chi_q, + PARAM.inp.of_ml_chi_xi, + PARAM.inp.of_ml_chi_pnl, + PARAM.inp.of_ml_chi_qnl, + PARAM.inp.of_ml_nkernel, + PARAM.inp.of_ml_kernel, + PARAM.inp.of_ml_kernel_scaling, + PARAM.inp.of_ml_yukawa_alpha, + PARAM.inp.of_ml_kernel_file, + ucell.omega, + this->pw_rho); + + write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, + this->kspw_psi, + this->pelec, + this->pw_wfc, + this->pw_rho, + ucell, + this->pelec->pot->get_effective_v(0)); + } +#endif + + ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); +} + +} // End ModuleIO diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h new file mode 100644 index 0000000000..8dc8520766 --- /dev/null +++ b/source/source_io/ctrl_output_pw.h @@ -0,0 +1,19 @@ +#ifndef CTRL_OUTPUT_PW_H +#define CTRL_OUTPUT_PW_H + +#include "source_estate/elecstate_lcao.h" + +namespace ModuleIO +{ + + void ctrl_output_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_Big* pw_big, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + surchem &solvent, + Parallel_Grid ¶_grid, + const int istep); + +} +#endif From bbe4c3f12301fd1b94d77ac77d11097ab59a30ee Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 14:26:35 +0800 Subject: [PATCH 10/18] add some interfaces in ctrl_output_pw --- source/Makefile.Objects | 1 + source/source_io/CMakeLists.txt | 1 + source/source_io/ctrl_output_pw.cpp | 238 +++++++++++++++------------- source/source_io/ctrl_output_pw.h | 29 +++- 4 files changed, 151 insertions(+), 118 deletions(-) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index db5585dd3a..8d4d9933a1 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -580,6 +580,7 @@ OBJS_IO=input_conv.o\ output_mat_sparse.o\ ctrl_output_lcao.o\ ctrl_output_fp.o\ + ctrl_output_pw.o\ para_json.o\ abacusjson.o\ general_info.o\ diff --git a/source/source_io/CMakeLists.txt b/source/source_io/CMakeLists.txt index ffa7598cd9..e511ffb35c 100644 --- a/source/source_io/CMakeLists.txt +++ b/source/source_io/CMakeLists.txt @@ -1,6 +1,7 @@ list(APPEND objects input_conv.cpp ctrl_output_fp.cpp + ctrl_output_pw.cpp bessel_basis.cpp cal_test.cpp cal_dos.cpp diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 6563d3e6c8..806b586c65 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -14,7 +14,14 @@ namespace ModuleIO { -void ctrl_iter_pw() +template +void ctrl_iter_pw(const int istep, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_iter_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_iter_pw"); @@ -28,11 +35,11 @@ void ctrl_iter_pw() int istep_in = -1; int iter_in = -1; bool out_wfc_flag = false; - if (PARAM.inp.out_freq_ion>0) // default value of out_freq_ion is 0 + if (inp.out_freq_ion>0) // default value of out_freq_ion is 0 { - if (istep % PARAM.inp.out_freq_ion == 0) + if (istep % inp.out_freq_ion == 0) { - if(iter % PARAM.inp.out_freq_elec == 0 || iter == PARAM.inp.scf_nmax || conv_esolver) + if(iter % inp.out_freq_elec == 0 || iter == inp.scf_nmax || conv_esolver) { istep_in = istep; iter_in = iter; @@ -40,7 +47,7 @@ void ctrl_iter_pw() } } } - else if(iter == PARAM.inp.scf_nmax || conv_esolver) + else if(iter == inp.scf_nmax || conv_esolver) { out_wfc_flag = true; } @@ -51,17 +58,17 @@ void ctrl_iter_pw() GlobalV::KPAR, GlobalV::MY_POOL, GlobalV::MY_RANK, - PARAM.inp.nbands, - PARAM.inp.nspin, + inp.nbands, + inp.nspin, PARAM.globalv.npol, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, - PARAM.inp.out_wfc_pw, - PARAM.inp.ecutwfc, + inp.out_wfc_pw, + inp.ecutwfc, PARAM.globalv.global_out_dir, - this->psi[0], - this->kv, - this->pw_wfc, + psi[0], + kv, + pw_wfc, GlobalV::ofs_running); } @@ -70,7 +77,14 @@ void ctrl_iter_pw() } -void ctrl_scf_pw() +void ctrl_scf_pw(elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rhod, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); @@ -78,20 +92,20 @@ void ctrl_scf_pw() //---------------------------------------------------------- //! 4) Compute density of states (DOS) //---------------------------------------------------------- - if (PARAM.inp.out_dos) + if (inp.out_dos) { bool out_dos_tmp = false; int istep_in = -1; // default value of out_freq_ion is 0 - if(PARAM.inp.out_freq_ion==0) + if(inp.out_freq_ion==0) { out_dos_tmp = true; } - else if (PARAM.inp.out_freq_ion>0) + else if (inp.out_freq_ion>0) { - if (istep % PARAM.inp.out_freq_ion == 0) + if (istep % inp.out_freq_ion == 0) { out_dos_tmp = true; istep_in=istep; @@ -108,7 +122,7 @@ void ctrl_scf_pw() // the above is only valid for KSDFT, not SDFT // this part needs update in the near future - if (PARAM.inp.esolver_type == "sdft") + if (inp.esolver_type == "sdft") { out_dos_tmp = false; } @@ -116,15 +130,15 @@ void ctrl_scf_pw() if(out_dos_tmp) { ModuleIO::write_dos_pw(ucell, - this->pelec->ekb, - this->pelec->wg, - this->kv, - PARAM.inp.nbands, + pelec->ekb, + pelec->wg, + kv, + inp.nbands, istep_in, - this->pelec->eferm, - PARAM.inp.dos_edelta_ev, - PARAM.inp.dos_scale, - PARAM.inp.dos_sigma, + pelec->eferm, + inp.dos_edelta_ev, + inp.dos_scale, + inp.dos_sigma, GlobalV::ofs_running); } } @@ -133,53 +147,53 @@ void ctrl_scf_pw() //------------------------------------------------------------------ // 5) calculate band-decomposed (partial) charge density in pw basis //------------------------------------------------------------------ - if (PARAM.inp.out_pchg.size() > 0) + if (inp.out_pchg.size() > 0) { - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") + if (__kspw_psi != nullptr && inp.precision == "single") { - delete reinterpret_cast, Device>*>(this->__kspw_psi); + delete reinterpret_cast, Device>*>(__kspw_psi); } // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); - - ModuleIO::get_pchg_pw(PARAM.inp.out_pchg, - this->kspw_psi->get_nbands(), - PARAM.inp.nspin, - this->pw_rhod->nxyz, - this->chr.ngmc, + __kspw_psi = inp.precision == "single" + ? new psi::Psi, Device>(kspw_psi[0]) + : reinterpret_cast, Device>*>(kspw_psi); + + ModuleIO::get_pchg_pw(inp.out_pchg, + kspw_psi->get_nbands(), + inp.nspin, + pw_rhod->nxyz, + chr.ngmc, &ucell, - this->__kspw_psi, - this->pw_rhod, - this->pw_wfc, + __kspw_psi, + pw_rhod, + pw_wfc, this->ctx, this->Pgrid, PARAM.globalv.global_out_dir, - PARAM.inp.if_separate_k, - this->kv, + inp.if_separate_k, + kv, GlobalV::KPAR, GlobalV::MY_POOL, - &this->chr); + &chr); } //------------------------------------------------------------------ //! 6) calculate Wannier functions in pw basis //------------------------------------------------------------------ - if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90) + if (inp.calculation == "nscf" && inp.towannier90) { std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation"); - toWannier90_PW wan(PARAM.inp.out_wannier_mmn, - PARAM.inp.out_wannier_amn, - PARAM.inp.out_wannier_unk, - PARAM.inp.out_wannier_eig, - PARAM.inp.out_wannier_wvfn_formatted, - PARAM.inp.nnkpfile, - PARAM.inp.wannier_spin); + toWannier90_PW wan(inp.out_wannier_mmn, + inp.out_wannier_amn, + inp.out_wannier_unk, + inp.out_wannier_eig, + inp.out_wannier_wvfn_formatted, + inp.nnkpfile, + inp.wannier_spin); wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, this->pelec->ekb, this->pw_wfc, this->pw_big, this->kv, this->psi); + wan.calculate(ucell, pelec->ekb, this->pw_wfc, this->pw_big, kv, this->psi); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); } @@ -187,11 +201,11 @@ void ctrl_scf_pw() //------------------------------------------------------------------ //! 7) calculate Berry phase polarization in pw basis //------------------------------------------------------------------ - if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1) + if (inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1) { std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); berryphase bp; - bp.Macroscopic_polarization(ucell, this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, this->kv); + bp.Macroscopic_polarization(ucell, this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, kv); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); } @@ -199,7 +213,7 @@ void ctrl_scf_pw() // 8) write spin constrian results in pw basis // spin constrain calculations, write atomic magnetization and magnetic force. //------------------------------------------------------------------ - if (PARAM.inp.sc_mag_switch) + if (inp.sc_mag_switch) { spinconstrain::SpinConstrain>& sc = spinconstrain::SpinConstrain>::getScInstance(); @@ -210,11 +224,11 @@ void ctrl_scf_pw() //------------------------------------------------------------------ // 9) write onsite occupations for charge and magnetizations //------------------------------------------------------------------ - if (PARAM.inp.onsite_radius > 0) + if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); onsite_p->cal_occupations(reinterpret_cast, Device>*>(this->kspw_psi), - this->pelec->wg); + pelec->wg); } ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); @@ -224,12 +238,16 @@ void ctrl_scf_pw() void ctrl_runner_pw(UnitCell& ucell, elecstate::ElecState* pelec, - ModulePW::PW_Basis_Big* pw_big, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, - Charge &chr, - surchem &solvent, + Charge &chr, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + surchem &solvent, Parallel_Grid ¶_grid, - const int istep) + const int istep, + const Input_para& inp); { ModuleBase::TITLE("ModuleIO", "ctrl_runner_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); @@ -237,9 +255,9 @@ void ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- //! 1) Compute LDOS //---------------------------------------------------------- - if (PARAM.inp.out_ldos[0]) + if (inp.out_ldos[0]) { - ModuleIO::cal_ldos_pw(reinterpret_cast>*>(this->pelec), + ModuleIO::cal_ldos_pw(reinterpret_cast>*>(pelec), this->psi[0], this->Pgrid, ucell); @@ -249,20 +267,20 @@ void ctrl_runner_pw(UnitCell& ucell, //! 2) Calculate the spillage value, //! which are used to generate numerical atomic orbitals //---------------------------------------------------------- - if (PARAM.inp.basis_type == "pw" && PARAM.inp.out_spillage) + if (inp.basis_type == "pw" && inp.out_spillage) { // ! Print out overlap matrices - if (PARAM.inp.out_spillage <= 2) + if (inp.out_spillage <= 2) { - for (int i = 0; i < PARAM.inp.bessel_nao_rcuts.size(); i++) + for (int i = 0; i < inp.bessel_nao_rcuts.size(); i++) { if (GlobalV::MY_RANK == 0) { - std::cout << "update value: bessel_nao_rcut <- " << std::fixed << PARAM.inp.bessel_nao_rcuts[i] + std::cout << "update value: bessel_nao_rcut <- " << std::fixed << inp.bessel_nao_rcuts[i] << " a.u." << std::endl; } Numerical_Basis numerical_basis; - numerical_basis.output_overlap(this->psi[0], this->sf, this->kv, this->pw_wfc, ucell, i); + numerical_basis.output_overlap(this->psi[0], this->sf, kv, pw_wfc, ucell, i); } ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); } @@ -271,30 +289,30 @@ void ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- //! 3) Print out electronic wave functions in real space //---------------------------------------------------------- - if (PARAM.inp.out_wfc_norm.size() > 0 || PARAM.inp.out_wfc_re_im.size() > 0) + if (inp.out_wfc_norm.size() > 0 || inp.out_wfc_re_im.size() > 0) { - if (this->__kspw_psi != nullptr && PARAM.inp.precision == "single") + if (__kspw_psi != nullptr && inp.precision == "single") { - delete reinterpret_cast, Device>*>(this->__kspw_psi); + delete reinterpret_cast, Device>*>(__kspw_psi); } // Refresh __kspw_psi - this->__kspw_psi = PARAM.inp.precision == "single" + __kspw_psi = inp.precision == "single" ? new psi::Psi, Device>(this->kspw_psi[0]) : reinterpret_cast, Device>*>(this->kspw_psi); - ModuleIO::get_wf_pw(PARAM.inp.out_wfc_norm, - PARAM.inp.out_wfc_re_im, + ModuleIO::get_wf_pw(inp.out_wfc_norm, + inp.out_wfc_re_im, this->kspw_psi->get_nbands(), - PARAM.inp.nspin, - this->pw_rhod->nxyz, + inp.nspin, + pw_rhod->nxyz, &ucell, - this->__kspw_psi, - this->pw_wfc, + __kspw_psi, + pw_wfc, this->ctx, this->Pgrid, PARAM.globalv.global_out_dir, - this->kv, + kv, GlobalV::KPAR, GlobalV::MY_POOL); } @@ -302,51 +320,51 @@ void ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- //! 4) Use Kubo-Greenwood method to compute conductivities //---------------------------------------------------------- - if (PARAM.inp.cal_cond) + if (inp.cal_cond) { - EleCond elec_cond(&ucell, &this->kv, this->pelec, this->pw_wfc, this->kspw_psi, &this->ppcell); - elec_cond.KG(PARAM.inp.cond_smear, - PARAM.inp.cond_fwhm, - PARAM.inp.cond_wcut, - PARAM.inp.cond_dw, - PARAM.inp.cond_dt, - PARAM.inp.cond_nonlocal, - this->pelec->wg); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, this->kspw_psi, &this->ppcell); + elec_cond.KG(inp.cond_smear, + inp.cond_fwhm, + inp.cond_wcut, + inp.cond_dw, + inp.cond_dt, + inp.cond_nonlocal, + pelec->wg); } #ifdef __MLALGO //---------------------------------------------------------- //! 7) generate training data for ML-KEDF //---------------------------------------------------------- - if (PARAM.inp.of_ml_gene_data == 1) + if (inp.of_ml_gene_data == 1) { - this->pelec->pot->update_from_charge(&this->chr, &ucell); + pelec->pot->update_from_charge(&chr, &ucell); ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc; - write_mlkedf_desc.cal_tool->set_para(this->chr.nrxx, - PARAM.inp.nelec, - PARAM.inp.of_tf_weight, - PARAM.inp.of_vw_weight, - PARAM.inp.of_ml_chi_p, - PARAM.inp.of_ml_chi_q, - PARAM.inp.of_ml_chi_xi, - PARAM.inp.of_ml_chi_pnl, - PARAM.inp.of_ml_chi_qnl, - PARAM.inp.of_ml_nkernel, - PARAM.inp.of_ml_kernel, - PARAM.inp.of_ml_kernel_scaling, - PARAM.inp.of_ml_yukawa_alpha, - PARAM.inp.of_ml_kernel_file, + write_mlkedf_desc.cal_tool->set_para(chr.nrxx, + inp.nelec, + inp.of_tf_weight, + inp.of_vw_weight, + inp.of_ml_chi_p, + inp.of_ml_chi_q, + inp.of_ml_chi_xi, + inp.of_ml_chi_pnl, + inp.of_ml_chi_qnl, + inp.of_ml_nkernel, + inp.of_ml_kernel, + inp.of_ml_kernel_scaling, + inp.of_ml_yukawa_alpha, + inp.of_ml_kernel_file, ucell.omega, - this->pw_rho); + pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, this->kspw_psi, - this->pelec, - this->pw_wfc, - this->pw_rho, + pelec, + pw_wfc, + pw_rho, ucell, - this->pelec->pot->get_effective_v(0)); + pelec->pot->get_effective_v(0)); } #endif diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 8dc8520766..1502cd5f03 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -6,14 +6,27 @@ namespace ModuleIO { - void ctrl_output_pw(UnitCell& ucell, - elecstate::ElecState* pelec, - ModulePW::PW_Basis_Big* pw_big, - ModulePW::PW_Basis* pw_rhod, - Charge &chr, - surchem &solvent, - Parallel_Grid ¶_grid, - const int istep); +template +void ctrl_iter_pw(const int istep, + const int iter, + const double &conv_esolver, + psi::Psi, base_device::DEVICE_CPU>* psi, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const Input_para& inp); + +template +void ctrl_scf_pw(elecstate::ElecState* pelec); + +template +void ctrl_runner_pw(UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_Big* pw_big, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + surchem &solvent, + Parallel_Grid ¶_grid, + const int istep); } #endif From 77ea8006c7e13bbb7aca7b74c7095459946c42b9 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 15:15:49 +0800 Subject: [PATCH 11/18] keep fixing bugs --- source/source_esolver/esolver_ks_pw.cpp | 41 +++++---- source/source_io/ctrl_output_pw.cpp | 114 +++++++++++++++++++----- source/source_io/ctrl_output_pw.h | 37 ++++++-- 3 files changed, 145 insertions(+), 47 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index ac522b495c..dc2e1cfc47 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -14,30 +14,30 @@ #include "source_hsolver/diago_iter_assist.h" #include "source_hsolver/hsolver_pw.h" #include "source_hsolver/kernels/dngvd_op.h" -#include "source_io/berryphase.h" -#include "source_io/cal_ldos.h" -#include "source_io/get_pchg_pw.h" -#include "source_io/get_wf_pw.h" +//#include "source_io/berryphase.h" +//#include "source_io/cal_ldos.h" +//#include "source_io/get_pchg_pw.h" +//#include "source_io/get_wf_pw.h" #include "source_io/module_parameter/parameter.h" -#include "source_io/numerical_basis.h" -#include "source_io/numerical_descriptor.h" -#include "source_io/to_wannier90_pw.h" -#include "source_io/write_dos_pw.h" -#include "source_io/write_wfc_pw.h" +//#include "source_io/numerical_basis.h" +//#include "source_io/numerical_descriptor.h" +//#include "source_io/to_wannier90_pw.h" +//#include "source_io/write_dos_pw.h" +//#include "source_io/write_wfc_pw.h" #include "source_lcao/module_deltaspin/spin_constrain.h" #include "source_lcao/module_dftu/dftu.h" #include "source_pw/module_pwdft/VSep_in_pw.h" #include "source_pw/module_pwdft/elecond.h" #include "source_pw/module_pwdft/forces.h" #include "source_pw/module_pwdft/hamilt_pw.h" -#include "source_pw/module_pwdft/onsite_projector.h" +//#include "source_pw/module_pwdft/onsite_projector.h" #include "source_pw/module_pwdft/stress_pw.h" #include -#ifdef __MLALGO -#include "source_io/write_mlkedf_descriptors.h" -#endif +//#ifdef __MLALGO +//#include "source_io/write_mlkedf_descriptors.h" +//#endif #include #include @@ -48,6 +48,8 @@ #include +#include "source_io/ctrl_output_pw.h" // mohan add 20250927 + namespace ModuleESolver { @@ -652,7 +654,8 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } - ctrl_iter_pw(); + ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, + this->kv, this->pw_wfc, PARAM.inp); } template @@ -687,7 +690,9 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->psi[0].size()); } - ModuleIO::ctrl_scf_pw(); + ModuleIO::ctrl_scf_pw(this->pelec, this->chr, this->kv, this->pw_wfc, + this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, + this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -771,8 +776,10 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) //---------------------------------------------------------- ESolver_KS::after_all_runners(ucell); - ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_big, this->pw_rhod, - this->chr, this->solvent, this->Pgrid, istep); + ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, + this->pw_rho, this->pw_rhod, this->chr, this->psi, + this->kspw_psi, this->__kspw_psi, this->sf, + this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); } diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 806b586c65..7d1aa48869 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -1,3 +1,8 @@ +#include "source_io/ctrl_output_pw.h" + +#include "source_io/write_wfc_pw.h" // use write_wfc_pw +#include "source_pw/module_pwdft/onsite_projector.h" // use projector + /* #include "source_io/ctrl_output_fp.h" // use ctrl_output_fp() #include "source_estate/module_charge/symmetry_rho.h" // use Symmetry_rho @@ -14,7 +19,6 @@ namespace ModuleIO { -template void ctrl_iter_pw(const int istep, const int iter, const double &conv_esolver, @@ -77,13 +81,19 @@ void ctrl_iter_pw(const int istep, } +template void ctrl_scf_pw(elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, + const Device* ctx, + const Parallel_Grid ¶_grid, const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_scf_pw"); @@ -121,7 +131,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, } // the above is only valid for KSDFT, not SDFT - // this part needs update in the near future + // Needs update in the near future if (inp.esolver_type == "sdft") { out_dos_tmp = false; @@ -168,8 +178,8 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, __kspw_psi, pw_rhod, pw_wfc, - this->ctx, - this->Pgrid, + ctx, + para_grid, PARAM.globalv.global_out_dir, inp.if_separate_k, kv, @@ -193,7 +203,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, inp.nnkpfile, inp.wannier_spin); wan.set_tpiba_omega(ucell.tpiba, ucell.omega); - wan.calculate(ucell, pelec->ekb, this->pw_wfc, this->pw_big, kv, this->psi); + wan.calculate(ucell, pelec->ekb, pw_wfc, pw_big, kv, psi); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation"); } @@ -205,7 +215,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, { std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization"); berryphase bp; - bp.Macroscopic_polarization(ucell, this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, kv); + bp.Macroscopic_polarization(ucell, pw_wfc->npwk_max, psi, pw_rho, pw_wfc, kv); std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization"); } @@ -227,7 +237,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, if (inp.onsite_radius > 0) { // float type has not been implemented auto* onsite_p = projectors::OnsiteProjector::get_instance(); - onsite_p->cal_occupations(reinterpret_cast, Device>*>(this->kspw_psi), + onsite_p->cal_occupations(reinterpret_cast, Device>*>(kspw_psi), pelec->wg); } @@ -235,19 +245,22 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, return; } - +template void ctrl_runner_pw(UnitCell& ucell, elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, + psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, surchem &solvent, + const Device* ctx, Parallel_Grid ¶_grid, - const int istep, - const Input_para& inp); + const Input_para& inp) { ModuleBase::TITLE("ModuleIO", "ctrl_runner_pw"); ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); @@ -258,9 +271,7 @@ void ctrl_runner_pw(UnitCell& ucell, if (inp.out_ldos[0]) { ModuleIO::cal_ldos_pw(reinterpret_cast>*>(pelec), - this->psi[0], - this->Pgrid, - ucell); + psi[0], para_grid, ucell); } //---------------------------------------------------------- @@ -280,7 +291,7 @@ void ctrl_runner_pw(UnitCell& ucell, << " a.u." << std::endl; } Numerical_Basis numerical_basis; - numerical_basis.output_overlap(this->psi[0], this->sf, kv, pw_wfc, ucell, i); + numerical_basis.output_overlap(psi[0], sf, kv, pw_wfc, ucell, i); } ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "BASIS OVERLAP (Q and S) GENERATION."); } @@ -298,19 +309,19 @@ void ctrl_runner_pw(UnitCell& ucell, // Refresh __kspw_psi __kspw_psi = inp.precision == "single" - ? new psi::Psi, Device>(this->kspw_psi[0]) - : reinterpret_cast, Device>*>(this->kspw_psi); + ? new psi::Psi, Device>(kspw_psi[0]) + : reinterpret_cast, Device>*>(kspw_psi); ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, - this->kspw_psi->get_nbands(), + kspw_psi->get_nbands(), inp.nspin, pw_rhod->nxyz, &ucell, __kspw_psi, pw_wfc, - this->ctx, - this->Pgrid, + ctx, + para_grid, PARAM.globalv.global_out_dir, kv, GlobalV::KPAR, @@ -322,7 +333,7 @@ void ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.cal_cond) { - EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, this->kspw_psi, &this->ppcell); + EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, kspw_psi, &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, inp.cond_wcut, @@ -359,7 +370,7 @@ void ctrl_runner_pw(UnitCell& ucell, pw_rho); write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir, - this->kspw_psi, + kspw_psi, pelec, pw_wfc, pw_rho, @@ -371,4 +382,65 @@ void ctrl_runner_pw(UnitCell& ucell, ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); } +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, + const base_device::DEVICE_CPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, + const base_device::DEVICE_CPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, + const base_device::DEVICE_CPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + +template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( + elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, + const base_device::DEVICE_CPU* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); + + } // End ModuleIO diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 1502cd5f03..010749c805 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -1,12 +1,12 @@ #ifndef CTRL_OUTPUT_PW_H #define CTRL_OUTPUT_PW_H -#include "source_estate/elecstate_lcao.h" +#include "source_psi/psi_init.h" // about psi +#include "source_estate/elecstate_lcao.h" // use pelec namespace ModuleIO { -template void ctrl_iter_pw(const int istep, const int iter, const double &conv_esolver, @@ -15,18 +15,37 @@ void ctrl_iter_pw(const int istep, const ModulePW::PW_Basis_K *pw_wfc, const Input_para& inp); -template -void ctrl_scf_pw(elecstate::ElecState* pelec); +template +void ctrl_scf_pw(elecstate::ElecState* pelec, + const Charge &chr, + const K_Vectors &kv, + const ModulePW::PW_Basis_K *pw_wfc, + const ModulePW::PW_Basis *pw_rho, + const ModulePW::PW_Basis *pw_rhod, + const ModulePW::PW_Basis_Big *pw_big, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + const Device* ctx, + const Parallel_Grid ¶_grid, + const Input_para& inp); -template +template void ctrl_runner_pw(UnitCell& ucell, elecstate::ElecState* pelec, - ModulePW::PW_Basis_Big* pw_big, - ModulePW::PW_Basis* pw_rhod, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, Charge &chr, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi* kspw_psi, + psi::Psi, Device>* __kspw_psi, + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, surchem &solvent, - Parallel_Grid ¶_grid, - const int istep); + const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); } #endif From 5364e44503ddc019d7e2e4e74ba43f09259aedf2 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 16:17:19 +0800 Subject: [PATCH 12/18] successfully compile the codes --- source/source_base/module_device/device.h | 12 +- source/source_esolver/esolver_ks_pw.cpp | 33 ++--- source/source_io/ctrl_output_pw.cpp | 161 +++++++++++++++++----- source/source_io/ctrl_output_pw.h | 12 +- 4 files changed, 146 insertions(+), 72 deletions(-) diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index a073bdab91..1b3498cc60 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -11,16 +11,6 @@ namespace base_device { -// struct CPU; -// struct GPU; - -// enum AbacusDevice_t -// { -// UnKnown, -// CpuDevice, -// GpuDevice -// }; - template base_device::AbacusDevice_t get_device_type(const Device* dev); @@ -122,4 +112,4 @@ static __inline__ __device__ double atomicAdd(double* address, double val) } #endif -#endif // MODULE_DEVICE_H_ \ No newline at end of file +#endif // MODULE_DEVICE_H_ diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index dc2e1cfc47..425dc2ef23 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -1,6 +1,6 @@ #include "esolver_ks_pw.h" -#include "source_base/formatter.h" +//#include "source_base/formatter.h" #include "source_base/global_variable.h" #include "source_base/kernels/math_kernel_op.h" #include "source_base/memory.h" @@ -14,23 +14,14 @@ #include "source_hsolver/diago_iter_assist.h" #include "source_hsolver/hsolver_pw.h" #include "source_hsolver/kernels/dngvd_op.h" -//#include "source_io/berryphase.h" -//#include "source_io/cal_ldos.h" -//#include "source_io/get_pchg_pw.h" -//#include "source_io/get_wf_pw.h" #include "source_io/module_parameter/parameter.h" -//#include "source_io/numerical_basis.h" -//#include "source_io/numerical_descriptor.h" -//#include "source_io/to_wannier90_pw.h" -//#include "source_io/write_dos_pw.h" -//#include "source_io/write_wfc_pw.h" #include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_pw/module_pwdft/onsite_projector.h" #include "source_lcao/module_dftu/dftu.h" #include "source_pw/module_pwdft/VSep_in_pw.h" -#include "source_pw/module_pwdft/elecond.h" +//#include "source_pw/module_pwdft/elecond.h" #include "source_pw/module_pwdft/forces.h" #include "source_pw/module_pwdft/hamilt_pw.h" -//#include "source_pw/module_pwdft/onsite_projector.h" #include "source_pw/module_pwdft/stress_pw.h" #include @@ -654,8 +645,8 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } - ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, - this->kv, this->pw_wfc, PARAM.inp); +// ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, + // this->kv, this->pw_wfc, PARAM.inp); } template @@ -690,9 +681,9 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->psi[0].size()); } - ModuleIO::ctrl_scf_pw(this->pelec, this->chr, this->kv, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, - this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp); +// ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, + // this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, + // this->__kspw_psi, /* this->ctx,*/ this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -776,10 +767,10 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) //---------------------------------------------------------- ESolver_KS::after_all_runners(ucell); - ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, - this->pw_rho, this->pw_rhod, this->chr, this->psi, - this->kspw_psi, this->__kspw_psi, this->sf, - this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); + //ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, + // this->pw_rho, this->pw_rhod, this->chr, this->psi, + // this->kspw_psi, this->__kspw_psi, this->sf, + // this->ppcell, this->solvent, /* this->ctx,*/ this->Pgrid, PARAM.inp); } diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 7d1aa48869..65ab3c2583 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -1,25 +1,20 @@ #include "source_io/ctrl_output_pw.h" #include "source_io/write_wfc_pw.h" // use write_wfc_pw +#include "source_io/write_dos_pw.h" // use write_dos_pw +#include "source_io/to_wannier90_pw.h" // wannier90 interface #include "source_pw/module_pwdft/onsite_projector.h" // use projector - -/* -#include "source_io/ctrl_output_fp.h" // use ctrl_output_fp() -#include "source_estate/module_charge/symmetry_rho.h" // use Symmetry_rho -#include "source_io/write_elecstat_pot.h" // use write_elecstat_pot -#include "source_io/write_elf.h" -#include "cube_io.h" // use write_vdata_palgrid -#include "source_hamilt/module_xc/xc_functional.h" // use XC_Functional - -#ifdef USE_LIBXC -#include "source_io/write_libxc_r.h" -#endif -*/ - -namespace ModuleIO -{ - -void ctrl_iter_pw(const int istep, +#include "source_io/numerical_basis.h" +#include "source_io/numerical_descriptor.h" +#include "source_io/cal_ldos.h" +#include "source_io/berryphase.h" +#include "source_lcao/module_deltaspin/spin_constrain.h" +#include "source_base/formatter.h" +#include "source_io/get_pchg_pw.h" +#include "source_io/get_wf_pw.h" +#include "source_pw/module_pwdft/elecond.h" + +void ModuleIO::ctrl_iter_pw(const int istep, const int iter, const double &conv_esolver, psi::Psi, base_device::DEVICE_CPU>* psi, @@ -82,7 +77,9 @@ void ctrl_iter_pw(const int istep, template -void ctrl_scf_pw(elecstate::ElecState* pelec, +void ModuleIO::ctrl_scf_pw(const int istep, + const UnitCell& ucell, + elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, const ModulePW::PW_Basis_K *pw_wfc, @@ -92,7 +89,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, - const Device* ctx, +// const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -169,8 +166,11 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, ? new psi::Psi, Device>(kspw_psi[0]) : reinterpret_cast, Device>*>(kspw_psi); + const int nbands = kspw_psi->get_nbands(); + +/* ModuleIO::get_pchg_pw(inp.out_pchg, - kspw_psi->get_nbands(), + nbands, inp.nspin, pw_rhod->nxyz, chr.ngmc, @@ -186,6 +186,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, GlobalV::KPAR, GlobalV::MY_POOL, &chr); +*/ } @@ -236,9 +237,11 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, //------------------------------------------------------------------ if (inp.onsite_radius > 0) { // float type has not been implemented +/* auto* onsite_p = projectors::OnsiteProjector::get_instance(); onsite_p->cal_occupations(reinterpret_cast, Device>*>(kspw_psi), pelec->wg); +*/ } ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); @@ -246,19 +249,20 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, } template -void ctrl_runner_pw(UnitCell& ucell, +void ModuleIO::ctrl_runner_pw(UnitCell& ucell, elecstate::ElecState* pelec, ModulePW::PW_Basis_K* pw_wfc, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, + const K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const Device* ctx, +// const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp) { @@ -312,6 +316,7 @@ void ctrl_runner_pw(UnitCell& ucell, ? new psi::Psi, Device>(kspw_psi[0]) : reinterpret_cast, Device>*>(kspw_psi); +/* ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, kspw_psi->get_nbands(), @@ -326,6 +331,7 @@ void ctrl_runner_pw(UnitCell& ucell, kv, GlobalV::KPAR, GlobalV::MY_POOL); +*/ } //---------------------------------------------------------- @@ -333,6 +339,7 @@ void ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.cal_cond) { +/* EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, kspw_psi, &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, @@ -341,6 +348,7 @@ void ctrl_runner_pw(UnitCell& ucell, inp.cond_dt, inp.cond_nonlocal, pelec->wg); +*/ } #ifdef __MLALGO @@ -383,6 +391,8 @@ void ctrl_runner_pw(UnitCell& ucell, } template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( + const int nstep, + const UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -391,13 +401,15 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, - const base_device::DEVICE_CPU* ctx, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device +// const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( + const int nstep, + const UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -406,13 +418,16 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_CPU>* kspw_psi, - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, - const base_device::DEVICE_CPU* ctx, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device +// const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); +#if ((defined __CUDA) || (defined __ROCM)) template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( + const int nstep, + const UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -421,13 +436,15 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, - const base_device::DEVICE_CPU* ctx, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device +// const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( + const int nstep, + const UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -436,11 +453,83 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP const ModulePW::PW_Basis *pw_rhod, const ModulePW::PW_Basis_Big *pw_big, psi::Psi, base_device::DEVICE_CPU>* psi, - psi::Psi, base_device::DEVICE_GPU>* kspw_psi, - psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, - const base_device::DEVICE_CPU* ctx, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + //const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); +#endif + +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_CPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + const K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, +// const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); + +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_CPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + const K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, +// const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); +#if ((defined __CUDA) || (defined __ROCM)) +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_GPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + const K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, +// const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); -} // End ModuleIO +template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_GPU>( + UnitCell& ucell, + elecstate::ElecState* pelec, + ModulePW::PW_Basis_K* pw_wfc, + ModulePW::PW_Basis* pw_rho, + ModulePW::PW_Basis* pw_rhod, + Charge &chr, + const K_Vectors &kv, + psi::Psi, base_device::DEVICE_CPU>* psi, + psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device + psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device + Structure_Factor &sf, + pseudopot_cell_vnl &ppcell, + surchem &solvent, +// const Device* ctx, + Parallel_Grid ¶_grid, + const Input_para& inp); +#endif diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 010749c805..492b3be6f8 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -1,7 +1,8 @@ #ifndef CTRL_OUTPUT_PW_H #define CTRL_OUTPUT_PW_H -#include "source_psi/psi_init.h" // about psi +#include "source_base/module_device/device.h" // use Device +#include "source_psi/psi.h" // define psi #include "source_estate/elecstate_lcao.h" // use pelec namespace ModuleIO @@ -16,7 +17,9 @@ void ctrl_iter_pw(const int istep, const Input_para& inp); template -void ctrl_scf_pw(elecstate::ElecState* pelec, +void ctrl_scf_pw(const int istep, + const UnitCell& ucell, + elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, const ModulePW::PW_Basis_K *pw_wfc, @@ -26,7 +29,7 @@ void ctrl_scf_pw(elecstate::ElecState* pelec, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, - const Device* ctx, +// const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -37,13 +40,14 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, + const K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, - const Device* ctx, +// const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp); From 55349809cf3961e9108ea67a7db80b4d89a087b0 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 16:31:43 +0800 Subject: [PATCH 13/18] finally I understand Devicegit add ../source/source_base/module_device/device.h ../source/source_io/ctrl_output_pw.cpp ../source/source_io/ctrl_output_pw.h! --- source/source_base/module_device/device.h | 1 - source/source_io/ctrl_output_pw.cpp | 6 +++--- source/source_io/ctrl_output_pw.h | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/source/source_base/module_device/device.h b/source/source_base/module_device/device.h index 1b3498cc60..7b8dd0c6ae 100644 --- a/source/source_base/module_device/device.h +++ b/source/source_base/module_device/device.h @@ -63,7 +63,6 @@ int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD); int stringCmp(const void* a, const void* b); #ifdef __CUDA - int set_device_by_rank(const MPI_Comm mpi_comm = MPI_COMM_WORLD); #endif diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 65ab3c2583..387330d9f2 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -89,7 +89,7 @@ void ModuleIO::ctrl_scf_pw(const int istep, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, -// const Device* ctx, + const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp) { @@ -403,7 +403,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device -// const base_device::DEVICE_CPU* ctx, + const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); @@ -420,7 +420,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device -// const base_device::DEVICE_CPU* ctx, + const base_device::DEVICE_CPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 492b3be6f8..2713be3719 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -29,7 +29,7 @@ void ctrl_scf_pw(const int istep, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, -// const Device* ctx, + const Device* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); From 9a1c9fdb602696e7ed9f0dc7c391717ef116b21f Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 16:36:31 +0800 Subject: [PATCH 14/18] update function variables --- source/source_io/ctrl_output_pw.cpp | 22 +++++++++++++++------- source/source_io/ctrl_output_pw.h | 9 ++++++--- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 387330d9f2..f69b4c5748 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -262,7 +262,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, -// const Device* ctx, + const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp) { @@ -390,6 +390,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModuleBase::timer::tick("ModuleIO", "ctrl_runner_pw"); } +// complex + CPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( const int nstep, const UnitCell& ucell, @@ -407,6 +408,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU const Parallel_Grid ¶_grid, const Input_para& inp); +// complex + CPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( const int nstep, const UnitCell& ucell, @@ -425,6 +427,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP const Input_para& inp); #if ((defined __CUDA) || (defined __ROCM)) +// complex + GPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( const int nstep, const UnitCell& ucell, @@ -438,10 +441,11 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device -// const base_device::DEVICE_CPU* ctx, + const base_device::DEVICE_GPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); +// complex + GPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( const int nstep, const UnitCell& ucell, @@ -455,11 +459,12 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GP psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device - //const base_device::DEVICE_CPU* ctx, + const base_device::DEVICE_GPU* ctx, const Parallel_Grid ¶_grid, const Input_para& inp); #endif +// complex + CPU template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_CPU>( UnitCell& ucell, elecstate::ElecState* pelec, @@ -474,10 +479,11 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, -// const Device* ctx, + const base_device::DEVICE_CPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); +// complex + CPU template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_CPU>( UnitCell& ucell, elecstate::ElecState* pelec, @@ -492,11 +498,12 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, -// const Device* ctx, + const base_device::DEVICE_CPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); #if ((defined __CUDA) || (defined __ROCM)) +// complex + GPU template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_GPU>( UnitCell& ucell, elecstate::ElecState* pelec, @@ -511,10 +518,11 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, -// const Device* ctx, + const base_device::DEVICE_GPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); +// complex + GPU template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_GPU>( UnitCell& ucell, elecstate::ElecState* pelec, @@ -529,7 +537,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, -// const Device* ctx, + const base_device::DEVICE_GPU* ctx, Parallel_Grid ¶_grid, const Input_para& inp); #endif diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 2713be3719..987ca2a92f 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -2,12 +2,13 @@ #define CTRL_OUTPUT_PW_H #include "source_base/module_device/device.h" // use Device -#include "source_psi/psi.h" // define psi -#include "source_estate/elecstate_lcao.h" // use pelec +#include "source_psi/psi.h" // define psi +#include "source_estate/elecstate_lcao.h" // use pelec namespace ModuleIO { +// print out information in 'iter_finish' in ESolver_KS_PW void ctrl_iter_pw(const int istep, const int iter, const double &conv_esolver, @@ -16,6 +17,7 @@ void ctrl_iter_pw(const int istep, const ModulePW::PW_Basis_K *pw_wfc, const Input_para& inp); +// print out information in 'after_scf' in ESolver_KS_PW template void ctrl_scf_pw(const int istep, const UnitCell& ucell, @@ -33,6 +35,7 @@ void ctrl_scf_pw(const int istep, const Parallel_Grid ¶_grid, const Input_para& inp); +// print out information in 'after_all_runners' in ESolver_KS_PW template void ctrl_runner_pw(UnitCell& ucell, elecstate::ElecState* pelec, @@ -47,7 +50,7 @@ void ctrl_runner_pw(UnitCell& ucell, Structure_Factor &sf, pseudopot_cell_vnl &ppcell, surchem &solvent, -// const Device* ctx, + const Device* ctx, Parallel_Grid ¶_grid, const Input_para& inp); From 40fa9f081014ec2f8d88fa7c6157cdae385b0b70 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 16:49:40 +0800 Subject: [PATCH 15/18] one step further --- source/source_io/ctrl_output_pw.cpp | 15 +++++++-------- source/source_io/ctrl_output_pw.h | 2 +- source/source_io/get_pchg_pw.h | 3 ++- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index f69b4c5748..4f6f04b50f 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -78,7 +78,7 @@ void ModuleIO::ctrl_iter_pw(const int istep, template void ModuleIO::ctrl_scf_pw(const int istep, - const UnitCell& ucell, + UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -167,13 +167,13 @@ void ModuleIO::ctrl_scf_pw(const int istep, : reinterpret_cast, Device>*>(kspw_psi); const int nbands = kspw_psi->get_nbands(); + const int ngmc = chr.ngmc; -/* ModuleIO::get_pchg_pw(inp.out_pchg, nbands, inp.nspin, pw_rhod->nxyz, - chr.ngmc, + ngmc, &ucell, __kspw_psi, pw_rhod, @@ -186,7 +186,6 @@ void ModuleIO::ctrl_scf_pw(const int istep, GlobalV::KPAR, GlobalV::MY_POOL, &chr); -*/ } @@ -393,7 +392,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, // complex + CPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( const int nstep, - const UnitCell& ucell, + UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -411,7 +410,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU // complex + CPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CPU>( const int nstep, - const UnitCell& ucell, + UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -430,7 +429,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_CP // complex + GPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( const int nstep, - const UnitCell& ucell, + UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, @@ -448,7 +447,7 @@ template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU // complex + GPU template void ModuleIO::ctrl_scf_pw, base_device::DEVICE_GPU>( const int nstep, - const UnitCell& ucell, + UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 987ca2a92f..2cf9cc41ae 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -20,7 +20,7 @@ void ctrl_iter_pw(const int istep, // print out information in 'after_scf' in ESolver_KS_PW template void ctrl_scf_pw(const int istep, - const UnitCell& ucell, + UnitCell& ucell, elecstate::ElecState* pelec, const Charge &chr, const K_Vectors &kv, diff --git a/source/source_io/get_pchg_pw.h b/source/source_io/get_pchg_pw.h index 7c2d14fe6a..2a61c77aa3 100644 --- a/source/source_io/get_pchg_pw.h +++ b/source/source_io/get_pchg_pw.h @@ -1,7 +1,8 @@ #ifndef GET_PCHG_PW_H #define GET_PCHG_PW_H -#include "cube_io.h" +#include "source_io/cube_io.h" +#include "source_estate/module_charge/symmetry_rho.h" namespace ModuleIO { From 155c18d051e5c1fd4e7526f25b6c1abd2d32d28f Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 17:01:27 +0800 Subject: [PATCH 16/18] move on --- source/source_io/ctrl_output_pw.cpp | 17 ++++++----------- source/source_io/ctrl_output_pw.h | 2 +- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 4f6f04b50f..67d143a274 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -236,11 +236,9 @@ void ModuleIO::ctrl_scf_pw(const int istep, //------------------------------------------------------------------ if (inp.onsite_radius > 0) { // float type has not been implemented -/* auto* onsite_p = projectors::OnsiteProjector::get_instance(); onsite_p->cal_occupations(reinterpret_cast, Device>*>(kspw_psi), pelec->wg); -*/ } ModuleBase::timer::tick("ModuleIO", "ctrl_scf_pw"); @@ -254,7 +252,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - const K_Vectors &kv, + K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, @@ -315,7 +313,6 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, ? new psi::Psi, Device>(kspw_psi[0]) : reinterpret_cast, Device>*>(kspw_psi); -/* ModuleIO::get_wf_pw(inp.out_wfc_norm, inp.out_wfc_re_im, kspw_psi->get_nbands(), @@ -330,7 +327,6 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, kv, GlobalV::KPAR, GlobalV::MY_POOL); -*/ } //---------------------------------------------------------- @@ -338,7 +334,7 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, //---------------------------------------------------------- if (inp.cal_cond) { -/* + using Real = typename GetTypeReal::type; EleCond elec_cond(&ucell, &kv, pelec, pw_wfc, kspw_psi, &ppcell); elec_cond.KG(inp.cond_smear, inp.cond_fwhm, @@ -347,7 +343,6 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell, inp.cond_dt, inp.cond_nonlocal, pelec->wg); -*/ } #ifdef __MLALGO @@ -471,7 +466,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - const K_Vectors &kv, + K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device @@ -490,7 +485,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - const K_Vectors &kv, + K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_CPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_CPU>* __kspw_psi, // Device @@ -510,7 +505,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE_ ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - const K_Vectors &kv, + K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device @@ -529,7 +524,7 @@ template void ModuleIO::ctrl_runner_pw, base_device::DEVICE ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - const K_Vectors &kv, + K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi, base_device::DEVICE_GPU>* kspw_psi, // T and Device psi::Psi, base_device::DEVICE_GPU>* __kspw_psi, // Device diff --git a/source/source_io/ctrl_output_pw.h b/source/source_io/ctrl_output_pw.h index 2cf9cc41ae..87fea245b0 100644 --- a/source/source_io/ctrl_output_pw.h +++ b/source/source_io/ctrl_output_pw.h @@ -43,7 +43,7 @@ void ctrl_runner_pw(UnitCell& ucell, ModulePW::PW_Basis* pw_rho, ModulePW::PW_Basis* pw_rhod, Charge &chr, - const K_Vectors &kv, + K_Vectors &kv, psi::Psi, base_device::DEVICE_CPU>* psi, psi::Psi* kspw_psi, psi::Psi, Device>* __kspw_psi, From 9c1c37672cff1387ae28dd52ae23b1b463829c42 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sat, 27 Sep 2025 17:13:21 +0800 Subject: [PATCH 17/18] update codes, done! --- source/source_esolver/esolver_ks_pw.cpp | 54 ++++++++----------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 425dc2ef23..daba4282c5 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -1,6 +1,5 @@ #include "esolver_ks_pw.h" -//#include "source_base/formatter.h" #include "source_base/global_variable.h" #include "source_base/kernels/math_kernel_op.h" #include "source_base/memory.h" @@ -19,17 +18,12 @@ #include "source_pw/module_pwdft/onsite_projector.h" #include "source_lcao/module_dftu/dftu.h" #include "source_pw/module_pwdft/VSep_in_pw.h" -//#include "source_pw/module_pwdft/elecond.h" #include "source_pw/module_pwdft/forces.h" #include "source_pw/module_pwdft/hamilt_pw.h" #include "source_pw/module_pwdft/stress_pw.h" #include -//#ifdef __MLALGO -//#include "source_io/write_mlkedf_descriptors.h" -//#endif - #include #include @@ -55,7 +49,6 @@ ESolver_KS_PW::ESolver_KS_PW() template ESolver_KS_PW::~ESolver_KS_PW() { - // delete Hamilt this->deallocate_hamilt(); @@ -495,6 +488,7 @@ void ESolver_KS_PW::hamilt2rho_single(UnitCell& ucell, const int iste skip_solve = true; } } + if (!skip_solve) { hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, @@ -645,8 +639,8 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int } } -// ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, - // this->kv, this->pw_wfc, PARAM.inp); + ModuleIO::ctrl_iter_pw(istep, iter, conv_esolver, this->psi, + this->kv, this->pw_wfc, PARAM.inp); } template @@ -681,9 +675,9 @@ void ESolver_KS_PW::after_scf(UnitCell& ucell, const int istep, const this->psi[0].size()); } -// ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, - // this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, - // this->__kspw_psi, /* this->ctx,*/ this->Pgrid, PARAM.inp); + ModuleIO::ctrl_scf_pw(istep, ucell, this->pelec, this->chr, this->kv, this->pw_wfc, + this->pw_rho, this->pw_rhod, this->pw_big, this->psi, this->kspw_psi, + this->__kspw_psi, this->ctx, this->Pgrid, PARAM.inp); ModuleBase::timer::tick("ESolver_KS_PW", "after_scf"); } @@ -710,18 +704,9 @@ void ESolver_KS_PW::cal_force(UnitCell& ucell, ModuleBase::matrix& fo : reinterpret_cast, Device>*>(this->kspw_psi); // Calculate forces - ff.cal_force(ucell, - force, - *this->pelec, - this->pw_rhod, - &ucell.symm, - &this->sf, - this->solvent, - &this->locpp, - &this->ppcell, - &this->kv, - this->pw_wfc, - this->__kspw_psi); + ff.cal_force(ucell, force, *this->pelec, this->pw_rhod, &ucell.symm, + &this->sf, this->solvent, &this->locpp, &this->ppcell, + &this->kv, this->pw_wfc, this->__kspw_psi); } template @@ -738,16 +723,9 @@ void ESolver_KS_PW::cal_stress(UnitCell& ucell, ModuleBase::matrix& s this->__kspw_psi = PARAM.inp.precision == "single" ? new psi::Psi, Device>(this->kspw_psi[0]) : reinterpret_cast, Device>*>(this->kspw_psi); - ss.cal_stress(stress, - ucell, - this->locpp, - this->ppcell, - this->pw_rhod, - &ucell.symm, - &this->sf, - &this->kv, - this->pw_wfc, - this->__kspw_psi); + + ss.cal_stress(stress, ucell, this->locpp, this->ppcell, this->pw_rhod, + &ucell.symm, &this->sf, &this->kv, this->pw_wfc, this->__kspw_psi); // external stress double unit_transform = 0.0; @@ -767,10 +745,10 @@ void ESolver_KS_PW::after_all_runners(UnitCell& ucell) //---------------------------------------------------------- ESolver_KS::after_all_runners(ucell); - //ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, - // this->pw_rho, this->pw_rhod, this->chr, this->psi, - // this->kspw_psi, this->__kspw_psi, this->sf, - // this->ppcell, this->solvent, /* this->ctx,*/ this->Pgrid, PARAM.inp); + ModuleIO::ctrl_runner_pw(ucell, this->pelec, this->pw_wfc, + this->pw_rho, this->pw_rhod, this->chr, this->kv, this->psi, + this->kspw_psi, this->__kspw_psi, this->sf, + this->ppcell, this->solvent, this->ctx, this->Pgrid, PARAM.inp); } From e6623f1d2dae1df4654a39a5cc622471580ff1a0 Mon Sep 17 00:00:00 2001 From: mohanchen Date: Sun, 28 Sep 2025 21:34:17 +0800 Subject: [PATCH 18/18] fix bug --- source/source_io/ctrl_output_pw.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/source_io/ctrl_output_pw.cpp b/source/source_io/ctrl_output_pw.cpp index 67d143a274..a8c588996a 100644 --- a/source/source_io/ctrl_output_pw.cpp +++ b/source/source_io/ctrl_output_pw.cpp @@ -14,6 +14,10 @@ #include "source_io/get_wf_pw.h" #include "source_pw/module_pwdft/elecond.h" +#ifdef __MLALGO +#include "source_io/write_mlkedf_descriptors.h" +#endif + void ModuleIO::ctrl_iter_pw(const int istep, const int iter, const double &conv_esolver,