Skip to content

Commit 071e8c3

Browse files
Update after_scf in ESolver (#5957)
* update lcao_after_scf, support RPA and LRI * update lcao_after_scf * update * update after_scf in esolver * fix ELF * fix ELF in LCAO --------- Co-authored-by: Hongxu Ren <[email protected]>
1 parent 7f9dc96 commit 071e8c3

File tree

5 files changed

+296
-232
lines changed

5 files changed

+296
-232
lines changed

source/module_esolver/esolver_fp.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,18 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
134134
{
135135
ModuleBase::TITLE("ESolver_FP", "after_scf");
136136

137-
// 0) output convergence information
137+
// 1) output convergence information
138138
ModuleIO::output_convergence_after_scf(conv_esolver, this->pelec->f_en.etot);
139139

140-
// 1) write fermi energy
140+
// 2) write fermi energy
141141
ModuleIO::output_efermi(conv_esolver, this->pelec->eferm.ef);
142142

143-
// 2) update delta rho for charge extrapolation
143+
// 3) update delta rho for charge extrapolation
144144
CE.update_delta_rho(ucell, &(this->chr), &(this->sf));
145145

146146
if (istep % PARAM.inp.out_interval == 0)
147147
{
148-
// 3) write charge density
148+
// 4) write charge density
149149
if (PARAM.inp.out_chg[0] > 0)
150150
{
151151
for (int is = 0; is < PARAM.inp.nspin; is++)
@@ -187,7 +187,7 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
187187
}
188188
}
189189

190-
// 4) write potential
190+
// 5) write potential
191191
if (PARAM.inp.out_pot == 1 || PARAM.inp.out_pot == 3)
192192
{
193193
for (int is = 0; is < PARAM.inp.nspin; is++)
@@ -223,7 +223,7 @@ void ESolver_FP::after_scf(UnitCell& ucell, const int istep, const bool conv_eso
223223
this->solvent);
224224
}
225225

226-
// 5) write ELF
226+
// 6) write ELF
227227
if (PARAM.inp.out_elf[0] > 0)
228228
{
229229
this->pelec->charge->cal_elf = true;

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,20 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
225225
{
226226
if (this->psi_laststep == nullptr)
227227
{
228+
int ncol_tmp = 0;
229+
int nrow_tmp = 0;
228230
#ifdef __MPI
229-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_nbands, nrow, kv.ngk, true);
231+
ncol_tmp = ncol_nbands;
232+
nrow_tmp = nrow;
230233
#else
231-
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), nbands, nlocal, kv.ngk, true);
234+
ncol_tmp = nbands;
235+
nrow_tmp = nlocal;
232236
#endif
237+
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_tmp, nrow_tmp, kv.ngk, true);
238+
233239
}
234240

241+
// allocate memory for Hk_laststep and Sk_laststep
235242
if (td_htype == 1)
236243
{
237244
// Length of Hk_laststep and Sk_laststep, nlocal * nlocal for global, nloc for local
@@ -259,11 +266,14 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
259266
}
260267
}
261268

269+
// put information to Hk_laststep and Sk_laststep
262270
for (int ik = 0; ik < kv.get_nks(); ++ik)
263271
{
264272
this->psi->fix_k(ik);
265273
this->psi_laststep->fix_k(ik);
266-
int size0 = psi->get_nbands() * psi->get_nbasis();
274+
275+
// copy the data from psi to psi_laststep
276+
const int size0 = psi->get_nbands() * psi->get_nbasis();
267277
for (int index = 0; index < size0; ++index)
268278
{
269279
psi_laststep[0].get_pointer()[index] = psi[0].get_pointer()[index];
@@ -273,7 +283,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
273283
if (td_htype == 1)
274284
{
275285
this->p_hamilt->updateHk(ik);
276-
hamilt::MatrixBlock<complex<double>> h_mat, s_mat;
286+
hamilt::MatrixBlock<complex<double>> h_mat;
287+
hamilt::MatrixBlock<complex<double>> s_mat;
277288
this->p_hamilt->matrix(h_mat, s_mat);
278289

279290
if (use_tensor && use_lapack)
@@ -285,7 +296,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
285296
MPI_Comm_rank(MPI_COMM_WORLD, &myid);
286297
MPI_Comm_size(MPI_COMM_WORLD, &num_procs);
287298

288-
Matrix_g<std::complex<double>> h_mat_g, s_mat_g; // Global matrix structure
299+
Matrix_g<std::complex<double>> h_mat_g; // Global matrix structure
300+
Matrix_g<std::complex<double>> s_mat_g; // Global matrix structure
289301

290302
// Collect H matrix
291303
gatherMatrix(myid, 0, h_mat, h_mat_g);
@@ -343,6 +355,9 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
343355
ModuleBase::TITLE("ESolver_LCAO_TDDFT", "after_scf");
344356
ModuleBase::timer::tick("ESolver_LCAO_TDDFT", "after_scf");
345357

358+
ESolver_KS_LCAO<std::complex<double>, double>::after_scf(ucell, istep, conv_esolver);
359+
360+
// (1) write dipole information
346361
for (int is = 0; is < PARAM.inp.nspin; is++)
347362
{
348363
if (PARAM.inp.out_dipole == 1)
@@ -357,6 +372,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
357372
ss_dipole.str());
358373
}
359374
}
375+
376+
// (2) write current information
360377
if (TD_Velocity::out_current == true)
361378
{
362379
elecstate::DensityMatrix<std::complex<double>, double>* tmp_DM
@@ -373,7 +390,7 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
373390
orb_,
374391
this->RA);
375392
}
376-
ESolver_KS_LCAO<std::complex<double>, double>::after_scf(ucell, istep, conv_esolver);
393+
377394

378395
ModuleBase::timer::tick("ESolver_LCAO_TDDFT", "after_scf");
379396
}
@@ -385,14 +402,18 @@ void ESolver_KS_LCAO_TDDFT<Device>::weight_dm_rho()
385402
{
386403
this->pelec->fixed_weights(PARAM.inp.ocp_kb, PARAM.inp.nbands, PARAM.inp.nelec);
387404
}
405+
406+
// calculate Eband energy
388407
this->pelec->calEBand();
389408

409+
// calculate the density matrix
390410
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
391411

392412
auto _pes = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec);
393413
elecstate::cal_dm_psi(_pes->DM->get_paraV_pointer(), _pes->wg, this->psi[0], *(_pes->DM));
394414
_pes->DM->cal_DMR();
395415

416+
// get the real-space charge density
396417
this->pelec->psiToRho(this->psi[0]);
397418
}
398419

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,33 +551,45 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
551551
ModuleBase::TITLE("ESolver_KS_PW", "after_scf");
552552
ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
553553

554-
// 1) calculate the kinetic energy density tau, sunliang 2024-09-18
554+
//------------------------------------------------------------------
555+
// 1) calculate the kinetic energy density tau in pw basis
556+
// sunliang 2024-09-18
557+
//------------------------------------------------------------------
555558
if (PARAM.inp.out_elf[0] > 0)
556559
{
557560
this->pelec->cal_tau(*(this->psi));
558561
}
559562

563+
//------------------------------------------------------------------
560564
// 2) call after_scf() of ESolver_KS
565+
//------------------------------------------------------------------
561566
ESolver_KS<T, Device>::after_scf(ucell, istep, conv_esolver);
562567

568+
569+
//------------------------------------------------------------------
563570
// 3) output wavefunctions in pw basis
571+
//------------------------------------------------------------------
564572
if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2)
565573
{
566574
std::stringstream ssw;
567575
ssw << PARAM.globalv.global_out_dir << "WAVEFUNC";
568576
ModuleIO::write_wfc_pw(ssw.str(), this->psi[0], this->kv, this->pw_wfc);
569577
}
570578

579+
//------------------------------------------------------------------
571580
// 4) transfer data from GPU to CPU in pw basis
572581
// a question: the wavefunctions have been output, then the data transfer occurs? mohan 20250302
582+
//------------------------------------------------------------------
573583
if (this->device == base_device::GpuDevice)
574584
{
575585
castmem_2d_d2h_op()(this->psi[0].get_pointer() - this->psi[0].get_psi_bias(),
576586
this->kspw_psi[0].get_pointer() - this->kspw_psi[0].get_psi_bias(),
577587
this->psi[0].size());
578588
}
579589

590+
//------------------------------------------------------------------
580591
// 5) calculate band-decomposed (partial) charge density in pw basis
592+
//------------------------------------------------------------------
581593
const std::vector<int> bands_to_print = PARAM.inp.bands_to_print;
582594
if (bands_to_print.size() > 0)
583595
{
@@ -604,7 +616,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
604616
PARAM.inp.if_separate_k);
605617
}
606618

607-
//! 6) calculate Wannier functions in PW basis
619+
//------------------------------------------------------------------
620+
//! 6) calculate Wannier functions in pw basis
621+
//------------------------------------------------------------------
608622
if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90)
609623
{
610624
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wannier functions calculation");
@@ -620,7 +634,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
620634
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wannier functions calculation");
621635
}
622636

623-
//! 7) calculate Berry phase polarization
637+
//------------------------------------------------------------------
638+
//! 7) calculate Berry phase polarization in pw basis
639+
//------------------------------------------------------------------
624640
if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1)
625641
{
626642
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase polarization");
@@ -629,8 +645,10 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
629645
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase polarization");
630646
}
631647

632-
// 8) write spin constrian results
648+
//------------------------------------------------------------------
649+
// 8) write spin constrian results in pw basis
633650
// spin constrain calculations, write atomic magnetization and magnetic force.
651+
//------------------------------------------------------------------
634652
if (PARAM.inp.sc_mag_switch)
635653
{
636654
spinconstrain::SpinConstrain<std::complex<double>>& sc
@@ -639,7 +657,9 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
639657
sc.print_Mag_Force(GlobalV::ofs_running);
640658
}
641659

660+
//------------------------------------------------------------------
642661
// 9) write onsite occupations for charge and magnetizations
662+
//------------------------------------------------------------------
643663
if (PARAM.inp.onsite_radius > 0)
644664
{ // float type has not been implemented
645665
auto* onsite_p = projectors::OnsiteProjector<double, Device>::get_instance();

source/module_esolver/esolver_of.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -489,30 +489,47 @@ void ESolver_OF::after_opt(const int istep, UnitCell& ucell, const bool conv_eso
489489
ModuleBase::TITLE("ESolver_OF", "after_opt");
490490
ModuleBase::timer::tick("ESolver_OF", "after_opt");
491491

492-
// 1) calculate the kinetic energy density
492+
//------------------------------------------------------------------
493+
// 1) calculate kinetic energy density and ELF
494+
//------------------------------------------------------------------
493495
if (PARAM.inp.out_elf[0] > 0)
494496
{
495497
this->kinetic_energy_density(this->pelec->charge->rho, this->pphi_, this->pelec->charge->kin_r);
496498
}
497499

500+
//------------------------------------------------------------------
501+
// 2) call after_scf() of ESolver_FP
502+
//------------------------------------------------------------------
503+
ESolver_FP::after_scf(ucell, istep, conv_esolver);
504+
505+
506+
// should not be here? mohan note 2025-03-03
498507
for (int ir = 0; ir < this->pw_rho->nrxx; ++ir)
499508
{
500509
this->pelec->charge->rho_save[0][ir] = this->pelec->charge->rho[0][ir];
501510
}
502511

503512
#ifdef __MLKEDF
513+
//------------------------------------------------------------------
504514
// Check the positivity of Pauli energy
515+
//------------------------------------------------------------------
505516
if (this->of_kinetic_ == "ml")
506517
{
507518
this->tf_->get_energy(this->pelec->charge->rho);
508-
std::cout << "ML Term = " << this->ml_->ml_energy << " Ry, TF Term = " << this->tf_->tf_energy << " Ry." << std::endl;
519+
520+
std::cout << "ML Term = " << this->ml_->ml_energy
521+
<< " Ry, TF Term = " << this->tf_->tf_energy
522+
<< " Ry." << std::endl;
523+
509524
if (this->ml_->ml_energy >= this->tf_->tf_energy)
510525
{
511526
std::cout << "WARNING: ML >= TF" << std::endl;
512527
}
513528
}
514529

530+
//------------------------------------------------------------------
515531
// Generate data if needed
532+
//------------------------------------------------------------------
516533
if (PARAM.inp.of_ml_gene_data)
517534
{
518535
this->pelec->pot->update_from_charge(pelec->charge, &ucell); // Hartree + XC + external
@@ -533,8 +550,6 @@ void ESolver_OF::after_opt(const int istep, UnitCell& ucell, const bool conv_eso
533550
this->ml_->generateTrainData(pelec->charge->rho, *(this->wt_), *(this->tf_), this->pw_rho, vr_eff);
534551
}
535552
#endif
536-
// 2) call after_scf() of ESolver_FP
537-
ESolver_FP::after_scf(ucell, istep, conv_esolver);
538553

539554
ModuleBase::timer::tick("ESolver_OF", "after_opt");
540555
}

0 commit comments

Comments
 (0)