Skip to content

Commit 00dcc80

Browse files
authored
Refactor: Move cal_tau to esolver_ks.cpp. (#6135)
* move cal_tau to esolver_ks.cpp * Fix: Fix the compile error with CUDA * Fix the compile error agian
1 parent 274477a commit 00dcc80

File tree

4 files changed

+38
-35
lines changed

4 files changed

+38
-35
lines changed

source/module_elecstate/elecstate.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ class ElecState
6060
{
6161
return;
6262
}
63+
virtual void cal_tau(const psi::Psi<std::complex<float>>& psi)
64+
{
65+
return;
66+
}
6367

6468
// update charge density for next scf step
6569
// in this function, 1. input rho for construct Hamilt and 2. calculated rho from Psi will mix to 3. new charge

source/module_esolver/esolver_ks.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -714,11 +714,18 @@ template <typename T, typename Device>
714714
void ESolver_KS<T, Device>::after_scf(UnitCell& ucell, const int istep, const bool conv_esolver)
715715
{
716716
ModuleBase::TITLE("ESolver_KS", "after_scf");
717-
718-
// 1) call after_scf() of ESolver_FP
717+
718+
// 1) calculate the kinetic energy density tau
719+
if (PARAM.inp.out_elf[0] > 0)
720+
{
721+
assert(this->psi != nullptr);
722+
this->pelec->cal_tau(*(this->psi));
723+
}
724+
725+
// 2) call after_scf() of ESolver_FP
719726
ESolver_FP::after_scf(ucell, istep, conv_esolver);
720727

721-
// 2) write eigenvalues
728+
// 3) write eigenvalues
722729
if (istep % PARAM.inp.out_interval == 0)
723730
{
724731
elecstate::print_eigenvalue(this->pelec->ekb,this->pelec->wg,this->pelec->klist,GlobalV::ofs_running);

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,12 +640,14 @@ void ESolver_KS_PW<T, Device>::after_scf(UnitCell& ucell, const int istep, const
640640
ModuleBase::timer::tick("ESolver_KS_PW", "after_scf");
641641

642642
//------------------------------------------------------------------
643-
// 1) calculate the kinetic energy density tau in pw basis
644-
// sunliang 2024-09-18
643+
// 1) since ESolver_KS::psi is hidden by ESolver_KS_PW::psi,
644+
// we need to copy the data from ESolver_KS::psi to ESolver_KS_PW::psi.
645+
// This part needs to be removed when we have a better design.
646+
// sunliang 2025-04-10
645647
//------------------------------------------------------------------
646648
if (PARAM.inp.out_elf[0] > 0)
647649
{
648-
this->pelec->cal_tau(*(this->psi));
650+
this->ESolver_KS<T, Device>::psi = new psi::Psi<T>(this->psi[0]);
649651
}
650652

651653
//------------------------------------------------------------------

source/module_esolver/lcao_after_scf.cpp

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,12 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
7878
ModuleBase::timer::tick("ESolver_KS_LCAO", "after_scf");
7979

8080
//------------------------------------------------------------------
81-
//! 1) calculate the kinetic energy density tau in LCAO basis
82-
//!sunliang 2024-09-18
83-
//------------------------------------------------------------------
84-
if (PARAM.inp.out_elf[0] > 0)
85-
{
86-
assert(this->psi != nullptr);
87-
this->pelec->cal_tau(*(this->psi));
88-
}
89-
90-
//------------------------------------------------------------------
91-
//! 2) call after_scf() of ESolver_KS
81+
//! 1) call after_scf() of ESolver_KS
9282
//------------------------------------------------------------------
9383
ESolver_KS<TK>::after_scf(ucell, istep, conv_esolver);
9484

9585
//------------------------------------------------------------------
96-
//! 3) write density matrix for sparse matrix in LCAO basis
86+
//! 2) write density matrix for sparse matrix in LCAO basis
9787
//------------------------------------------------------------------
9888
ModuleIO::write_dmr(dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMR_vector(),
9989
this->pv,
@@ -105,7 +95,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
10595
istep);
10696

10797
//------------------------------------------------------------------
108-
//! 4) write density matrix in LCAO basis
98+
//! 3) write density matrix in LCAO basis
10999
//------------------------------------------------------------------
110100
if (PARAM.inp.out_dm)
111101
{
@@ -124,7 +114,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
124114

125115
#ifdef __EXX
126116
//------------------------------------------------------------------
127-
//! 5) write Hexx matrix in LCAO basis
117+
//! 4) write Hexx matrix in LCAO basis
128118
// (see `out_chg` in docs/advanced/input_files/input-main.md)
129119
//------------------------------------------------------------------
130120
if (PARAM.inp.calculation != "nscf")
@@ -146,7 +136,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
146136
#endif
147137

148138
//------------------------------------------------------------------
149-
// 6) write Hamiltonian and Overlap matrix in LCAO basis
139+
// 5) write Hamiltonian and Overlap matrix in LCAO basis
150140
//------------------------------------------------------------------
151141
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
152142
{
@@ -193,7 +183,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
193183
}
194184

195185
//------------------------------------------------------------------
196-
// 7) write electronic wavefunctions in LCAO basis
186+
// 6) write electronic wavefunctions in LCAO basis
197187
//------------------------------------------------------------------
198188
if (elecstate::ElecStateLCAO<TK>::out_wfc_lcao && (istep % PARAM.inp.out_interval == 0))
199189
{
@@ -207,7 +197,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
207197
}
208198

209199
//------------------------------------------------------------------
210-
//! 8) write DeePKS information in LCAO basis
200+
//! 7) write DeePKS information in LCAO basis
211201
//------------------------------------------------------------------
212202
#ifdef __DEEPKS
213203
if (this->psi != nullptr && (istep % PARAM.inp.out_interval == 0))
@@ -234,7 +224,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
234224
#endif
235225

236226
//------------------------------------------------------------------
237-
//! 9) Perform RDMFT calculations
227+
//! 8) Perform RDMFT calculations
238228
// rdmft, added by jghan, 2024-10-17
239229
//------------------------------------------------------------------
240230
if (PARAM.inp.rdmft == true)
@@ -263,7 +253,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
263253

264254
#ifdef __EXX
265255
//------------------------------------------------------------------
266-
// 10) Write RPA information in LCAO basis
256+
// 9) Write RPA information in LCAO basis
267257
//------------------------------------------------------------------
268258
if (PARAM.inp.rpa)
269259
{
@@ -279,7 +269,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
279269
#endif
280270

281271
//------------------------------------------------------------------
282-
// 11) write HR in npz format in LCAO basis
272+
// 10) write HR in npz format in LCAO basis
283273
//------------------------------------------------------------------
284274
if (PARAM.inp.out_hr_npz)
285275
{
@@ -300,7 +290,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
300290
}
301291

302292
//------------------------------------------------------------------
303-
// 12) write density matrix in the 'npz' format in LCAO basis
293+
// 11) write density matrix in the 'npz' format in LCAO basis
304294
//------------------------------------------------------------------
305295
if (PARAM.inp.out_dm_npz)
306296
{
@@ -317,7 +307,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
317307
}
318308

319309
//------------------------------------------------------------------
320-
//! 13) Print out information every 'out_interval' steps.
310+
//! 12) Print out information every 'out_interval' steps.
321311
//------------------------------------------------------------------
322312
if (PARAM.inp.calculation != "md" || istep % PARAM.inp.out_interval == 0)
323313
{
@@ -354,7 +344,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
354344
}
355345

356346
//------------------------------------------------------------------
357-
//! 14) Print out atomic magnetization in LCAO basis
347+
//! 13) Print out atomic magnetization in LCAO basis
358348
//! only when 'spin_constraint' is on.
359349
//------------------------------------------------------------------
360350
if (PARAM.inp.sc_mag_switch)
@@ -366,7 +356,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
366356
}
367357

368358
//------------------------------------------------------------------
369-
//! 15) Print out kinetic matrix in LCAO basis
359+
//! 14) Print out kinetic matrix in LCAO basis
370360
//------------------------------------------------------------------
371361
if (PARAM.inp.out_mat_tk[0])
372362
{
@@ -402,7 +392,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
402392
}
403393

404394
//------------------------------------------------------------------
405-
//! 16) wannier90 interface in LCAO basis
395+
//! 15) wannier90 interface in LCAO basis
406396
// added by jingan in 2018.11.7
407397
//------------------------------------------------------------------
408398
if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90)
@@ -444,7 +434,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
444434
}
445435

446436
//------------------------------------------------------------------
447-
//! 17) berry phase calculations in LCAO basis
437+
//! 16) berry phase calculations in LCAO basis
448438
// added by jingan
449439
//------------------------------------------------------------------
450440
if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1)
@@ -458,7 +448,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
458448
}
459449

460450
//------------------------------------------------------------------
461-
//! 18) calculate quasi-orbitals in LCAO basis
451+
//! 17) calculate quasi-orbitals in LCAO basis
462452
//------------------------------------------------------------------
463453
if (PARAM.inp.qo_switch)
464454
{
@@ -475,15 +465,15 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
475465
}
476466

477467
//------------------------------------------------------------------
478-
//! 19) Clean up RA, which is used to serach for adjacent atoms
468+
//! 18) Clean up RA, which is used to serach for adjacent atoms
479469
//------------------------------------------------------------------
480470
if (!PARAM.inp.cal_force && !PARAM.inp.cal_stress)
481471
{
482472
RA.delete_grid();
483473
}
484474

485475
//------------------------------------------------------------------
486-
//! 20) calculate expectation of angular momentum operator in LCAO basis
476+
//! 19) calculate expectation of angular momentum operator in LCAO basis
487477
//------------------------------------------------------------------
488478
if (PARAM.inp.out_mat_l[0])
489479
{

0 commit comments

Comments
 (0)