Skip to content

Commit 130f486

Browse files
committed
update after_scf in esolver
1 parent cb37d7d commit 130f486

File tree

5 files changed

+301
-240
lines changed

5 files changed

+301
-240
lines changed

source/module_esolver/esolver_fp.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,18 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
131131
{
132132
ModuleBase::TITLE("ESolver_FP", "after_scf");
133133

134-
// 0) output convergence information
134+
// 1) output convergence information
135135
ModuleIO::output_convergence_after_scf(conv_esolver, this->pelec->f_en.etot);
136136

137-
// 1) write fermi energy
137+
// 2) write fermi energy
138138
ModuleIO::output_efermi(conv_esolver, this->pelec->eferm.ef);
139139

140-
// 2) update delta rho for charge extrapolation
140+
// 3) update delta rho for charge extrapolation
141141
CE.update_delta_rho(ucell, &(this->chr), &(this->sf));
142142

143143
if (istep % PARAM.inp.out_interval == 0)
144144
{
145-
// 3) write charge density
145+
// 4) write charge density
146146
if (PARAM.inp.out_chg[0] > 0)
147147
{
148148
for (int is = 0; is < PARAM.inp.nspin; is++)
@@ -184,7 +184,7 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
184184
}
185185
}
186186

187-
// 4) write potential
187+
// 5) write potential
188188
if (PARAM.inp.out_pot == 1 || PARAM.inp.out_pot == 3)
189189
{
190190
for (int is = 0; is < PARAM.inp.nspin; is++)
@@ -220,7 +220,7 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
220220
this->solvent);
221221
}
222222

223-
// 5) write ELF
223+
// 6) write ELF
224224
if (PARAM.inp.out_elf[0] > 0)
225225
{
226226
this->pelec->charge->cal_elf = true;

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,20 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
225225
{
226226
if (this->psi_laststep == nullptr)
227227
{
228+
int ncol_tmp = 0;
229+
int nrow_tmp = 0;
228230
#ifdef __MPI
229-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_nbands, nrow, kv.ngk, true);
231+
ncol_tmp = ncol_nbands;
232+
nrow_tmp = nrow;
230233
#else
231-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), nbands, nlocal, kv.ngk, true);
234+
ncol_tmp = nbands;
235+
nrow_tmp = nlocal;
232236
#endif
237+
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_tmp, nrow_tmp, kv.ngk, true);
238+
233239
}
234240

241+
// allocate memory for Hk_laststep and Sk_laststep
235242
if (td_htype == 1)
236243
{
237244
// Length of Hk_laststep and Sk_laststep, nlocal * nlocal for global, nloc for local
@@ -259,11 +266,14 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
259266
}
260267
}
261268

269+
// put information to Hk_laststep and Sk_laststep
262270
for (int ik = 0; ik < kv.get_nks(); ++ik)
263271
{
264272
this->psi->fix_k(ik);
265273
this->psi_laststep->fix_k(ik);
266-
int size0 = psi->get_nbands() * psi->get_nbasis();
274+
275+
// copy the data from psi to psi_laststep
276+
const int size0 = psi->get_nbands() * psi->get_nbasis();
267277
for (int index = 0; index < size0; ++index)
268278
{
269279
psi_laststep[0].get_pointer()[index] = psi[0].get_pointer()[index];
@@ -273,7 +283,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
273283
if (td_htype == 1)
274284
{
275285
this->p_hamilt->updateHk(ik);
276-
hamilt::MatrixBlock<complex<double>> h_mat, s_mat;
286+
hamilt::MatrixBlock<complex<double>> h_mat;
287+
hamilt::MatrixBlock<complex<double>> s_mat;
277288
this->p_hamilt->matrix(h_mat, s_mat);
278289

279290
if (use_tensor && use_lapack)
@@ -285,7 +296,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
285296
MPI_Comm_rank(MPI_COMM_WORLD, &myid);
286297
MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
287298

288-
Matrix_g<std::complex<double>> h_mat_g, s_mat_g; // Global matrix structure
299+
Matrix_g<std::complex<double>> h_mat_g; // Global matrix structure
300+
Matrix_g<std::complex<double>> s_mat_g; // Global matrix structure
289301

290302
// Collect H matrix
291303
gatherMatrix(myid, 0, h_mat, h_mat_g);
@@ -343,6 +355,9 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
343355
ModuleBase::TITLE("ESolver_LCAO_TDDFT", "after_scf");
344356
ModuleBase::timer::tick("ESolver_LCAO_TDDFT", "after_scf");
345357

358+
ESolver_KS_LCAO<std::complex<double>, double>::after_scf(ucell, istep, conv_esolver);
359+
360+
// (1) write dipole information
346361
for (int is = 0; is < PARAM.inp.nspin; is++)
347362
{
348363
if (PARAM.inp.out_dipole == 1)
@@ -357,6 +372,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
357372
ss_dipole.str());
358373
}
359374
}
375+
376+
// (2) write current information
360377
if (TD_Velocity::out_current == true)
361378
{
362379
elecstate::DensityMatrix<std::complex<double>, double>* tmp_DM
@@ -373,7 +390,7 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
373390
orb_,
374391
this->RA);
375392
}
376-
ESolver_KS_LCAO<std::complex<double>, double>::after_scf(ucell, istep, conv_esolver);
393+
377394

378395
ModuleBase::timer::tick("ESolver_LCAO_TDDFT", "after_scf");
379396
}
@@ -385,14 +402,18 @@ void ESolver_KS_LCAO_TDDFT<Device>::weight_dm_rho()
385402
{
386403
this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.inp.nbands, PARAM.inp.nelec);
387404
}
405+
406+
// calculate Eband energy
388407
this->pelec->calEBand();
389408

409+
// calculate the density matrix
390410
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
391411

392412
auto _pes = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec);
393413
elecstate::cal_dm_psi(_pes->DM->get_paraV_pointer(), _pes->wg, this->psi[0], *(_pes->DM));
394414
_pes->DM->cal_DMR();
395415

416+
// get the real-space charge density
396417
this->pelec->psiToRho(this->psi[0]);
397418
}
398419

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -638,33 +638,44 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
638638
ModuleBase::TITLE("ESolver_KS_PW", "after_scf");
639639
ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
640640

641-
// 1) calculate the kinetic energy density tau, sunliang 2024-09-18
641+
//------------------------------------------------------------------
642+
// 1) call after_scf() of ESolver_KS
643+
//------------------------------------------------------------------
644+
ESolver_KS<T, Device>::after_scf(ucell, istep, conv_esolver);
645+
646+
//------------------------------------------------------------------
647+
// 2) calculate the kinetic energy density tau in pw basis
648+
// sunliang 2024-09-18
649+
//------------------------------------------------------------------
642650
if (PARAM.inp.out_elf[0] > 0)
643651
{
644652
this->pelec->cal_tau(*(this->psi));
645653
}
646654

647-
// 2) call after_scf() of ESolver_KS
648-
ESolver_KS<T, Device>::after_scf(ucell, istep, conv_esolver);
649-
655+
//------------------------------------------------------------------
650656
// 3) output wavefunctions in pw basis
657+
//------------------------------------------------------------------
651658
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
652659
{
653660
std::stringstream ssw;
654661
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
655662
ModuleIO::write_wfc_pw(ssw.str(), this->psi[0], this->kv, this->pw_wfc);
656663
}
657664

665+
//------------------------------------------------------------------
658666
// 4) transfer data from GPU to CPU in pw basis
659667
// a question: the wavefunctions have been output, then the data transfer occurs? mohan 20250302
668+
//------------------------------------------------------------------
660669
if (this->device == base_device::GpuDevice)
661670
{
662671
castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(),
663672
this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(),
664673
this->psi[0].size());
665674
}
666675

676+
//------------------------------------------------------------------
667677
// 5) calculate band-decomposed (partial) charge density in pw basis
678+
//------------------------------------------------------------------
668679
const std::vector<int> bands_to_print = PARAM.inp.bands_to_print;
669680
if (bands_to_print.size() > 0)
670681
{
@@ -691,7 +702,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
691702
PARAM.inp.if_separate_k);
692703
}
693704

694-
//! 6) calculate Wannier functions in PW basis
705+
//------------------------------------------------------------------
706+
//! 6) calculate Wannier functions in pw basis
707+
//------------------------------------------------------------------
695708
if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90)
696709
{
697710
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation");
@@ -707,7 +720,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
707720
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation");
708721
}
709722

710-
//! 7) calculate Berry phase polarization
723+
//------------------------------------------------------------------
724+
//! 7) calculate Berry phase polarization in pw basis
725+
//------------------------------------------------------------------
711726
if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1)
712727
{
713728
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization");
@@ -716,8 +731,10 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
716731
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization");
717732
}
718733

719-
// 8) write spin constrian results
734+
//------------------------------------------------------------------
735+
// 8) write spin constrian results in pw basis
720736
// spin constrain calculations, write atomic magnetization and magnetic force.
737+
//------------------------------------------------------------------
721738
if (PARAM.inp.sc_mag_switch)
722739
{
723740
spinconstrain::SpinConstrain<std::complex<double>>& sc
@@ -726,7 +743,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
726743
sc.print_Mag_Force(GlobalV::ofs_running);
727744
}
728745

746+
//------------------------------------------------------------------
729747
// 9) write onsite occupations for charge and magnetizations
748+
//------------------------------------------------------------------
730749
if (PARAM.inp.onsite_radius > 0)
731750
{ // float type has not been implemented
732751
auto* onsite_p = projectors::OnsiteProjector<double, Device>::get_instance();

source/module_esolver/esolver_of.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,30 +497,46 @@ void ESolver_OF::after_opt(const int istep, UnitCell& ucell, const bool conv_eso
497497
ModuleBase::TITLE("ESolver_OF", "after_opt");
498498
ModuleBase::timer::tick("ESolver_OF", "after_opt");
499499

500-
// 1) calculate the kinetic energy density
500+
//------------------------------------------------------------------
501+
// 1) call after_scf() of ESolver_FP
502+
//------------------------------------------------------------------
503+
ESolver_FP::after_scf(ucell, istep, conv_esolver);
504+
505+
//------------------------------------------------------------------
506+
// 2) calculate kinetic energy density and ELF
507+
//------------------------------------------------------------------
501508
if (PARAM.inp.out_elf[0] > 0)
502509
{
503510
this->kinetic_energy_density(this->pelec->charge->rho, this->pphi_, this->pelec->charge->kin_r);
504511
}
505512

513+
// should not be here? mohan note 2025-03-03
506514
for (int ir = 0; ir < this->pw_rho->nrxx; ++ir)
507515
{
508516
this->pelec->charge->rho_save[0][ir] = this->pelec->charge->rho[0][ir];
509517
}
510518

511519
#ifdef __MLKEDF
520+
//------------------------------------------------------------------
512521
// Check the positivity of Pauli energy
522+
//------------------------------------------------------------------
513523
if (this->of_kinetic_ == "ml")
514524
{
515525
this->tf_->get_energy(this->pelec->charge->rho);
516-
std::cout << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl;
526+
527+
std::cout << "ML Term = " << this->ml_->ml_energy
528+
<< " Ry, TF Term = " << this->tf_->tf_energy
529+
<< " Ry." << std::endl;
530+
517531
if (this->ml_->ml_energy >= this->tf_->tf_energy)
518532
{
519533
std::cout << "WARNING: ML >= TF" << std::endl;
520534
}
521535
}
522536

537+
//------------------------------------------------------------------
523538
// Generate data if needed
539+
//------------------------------------------------------------------
524540
if (PARAM.inp.of_ml_gene_data)
525541
{
526542
this->pelec->pot->update_from_charge(pelec->charge, &ucell); // Hartree + XC + external
@@ -541,8 +557,6 @@ void ESolver_OF::after_opt(const int istep, UnitCell& ucell, const bool conv_eso
541557
this->ml_->generateTrainData(pelec->charge->rho, *(this->wt_), *(this->tf_), this->pw_rho, vr_eff);
542558
}
543559
#endif
544-
// 2) call after_scf() of ESolver_FP
545-
ESolver_FP::after_scf(ucell, istep, conv_esolver);
546560

547561
ModuleBase::timer::tick("ESolver_OF", "after_opt");
548562
}

0 commit comments

Comments
 (0)