Skip to content

Commit bda3d33

Browse files
committed
move calculate_weight() out of psi2rho()
1 parent 2e60ec1 commit bda3d33

17 files changed

+95
-66
lines changed

source/module_elecstate/elecstate_lcao.cpp

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#include "elecstate_lcao.h"
22

3-
#include <vector>
4-
53
#include "cal_dm.h"
64
#include "module_base/timer.h"
75
#include "module_elecstate/module_dm/cal_dm_psi.h"
@@ -11,6 +9,8 @@
119
#include "module_hamilt_pw/hamilt_pwdft/global.h"
1210
#include "module_parameter/parameter.h"
1311

12+
#include <vector>
13+
1414
namespace elecstate
1515
{
1616

@@ -21,34 +21,31 @@ void ElecStateLCAO<std::complex<double>>::psiToRho(const psi::Psi<std::complex<d
2121
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
2222
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
2323

24-
this->calculate_weights();
25-
26-
// the calculations of dm, and dm -> rho are, technically, two separate
27-
// functionalities, as we cannot rule out the possibility that we may have a
28-
// dm from other sources, such as read from file. However, since we are not
29-
// separating them now, I opt to add a flag to control how dm is obtained as
30-
// of now
31-
if (!PARAM.inp.dm_to_rho)
32-
{
33-
this->calEBand();
34-
35-
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
36-
37-
// this part for calculating DMK in 2d-block format, not used for charge
38-
// now
39-
// psi::Psi<std::complex<double>> dm_k_2d();
40-
41-
if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver == "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
42-
|| PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp"
43-
|| PARAM.inp.ks_solver == "cg_in_lcao") // Peize Lin test 2019-05-15
44-
{
45-
elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
46-
this->wg,
47-
psi,
48-
*(this->DM));
49-
this->DM->cal_DMR();
50-
}
51-
}
24+
// // the calculations of dm, and dm -> rho are, technically, two separate
25+
// // functionalities, as we cannot rule out the possibility that we may have a
26+
// // dm from other sources, such as read from file. However, since we are not
27+
// // separating them now, I opt to add a flag to control how dm is obtained as
28+
// // of now
29+
// if (!PARAM.inp.dm_to_rho)
30+
// {
31+
// ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
32+
33+
// // this part for calculating DMK in 2d-block format, not used for charge
34+
// // now
35+
// // psi::Psi<std::complex<double>> dm_k_2d();
36+
37+
// if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver ==
38+
// "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
39+
// || PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp"
40+
// || PARAM.inp.ks_solver == "cg_in_lcao") // Peize Lin test 2019-05-15
41+
// {
42+
// elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
43+
// this->wg,
44+
// psi,
45+
// *(this->DM));
46+
// this->DM->cal_DMR();
47+
// }
48+
// }
5249

5350
for (int is = 0; is < PARAM.inp.nspin; is++)
5451
{
@@ -83,23 +80,6 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)
8380
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
8481
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
8582

86-
this->calculate_weights();
87-
this->calEBand();
88-
89-
if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver == "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
90-
|| PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp" || PARAM.inp.ks_solver == "cg_in_lcao")
91-
{
92-
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
93-
94-
// get DMK in 2d-block format
95-
elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
96-
this->wg,
97-
psi,
98-
*(this->DM));
99-
this->DM->cal_DMR();
100-
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
101-
}
102-
10383
for (int is = 0; is < PARAM.inp.nspin; is++)
10484
{
10585
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is],

source/module_elecstate/elecstate_lcao.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class ElecStateLCAO : public ElecState
7474
void dmToRho(std::vector<TK*> pexsi_DM, std::vector<TK*> pexsi_EDM);
7575
#endif
7676

77+
DensityMatrix<TK, double>* DM = nullptr;
78+
7779
protected:
7880
// calculate electronic charge density on grid points or density matrix in real space
7981
// the consequence charge density rho saved into rho_out, preparing for charge mixing.
@@ -85,7 +87,6 @@ class ElecStateLCAO : public ElecState
8587

8688
Gint_Gamma* gint_gamma = nullptr; // mohan add 2024-04-01
8789
Gint_k* gint_k = nullptr; // mohan add 2024-04-01
88-
DensityMatrix<TK, double>* DM = nullptr;
8990
};
9091

9192
template <typename TK>

source/module_elecstate/elecstate_pw.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,6 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
8989
ModuleBase::timer::tick("ElecStatePW", "psiToRho");
9090

9191
this->init_rho_data();
92-
this->calculate_weights();
93-
94-
this->calEBand();
9592

9693
for(int is=0; is<PARAM.inp.nspin; is++)
9794
{

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
1717

1818
if (GlobalV::MY_STOGROUP == 0)
1919
{
20-
this->calEBand();
21-
2220
for (int is = 0; is < nspin; is++)
2321
{
2422
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);

source/module_esolver/esolver_ks.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ void ESolver_KS<T, Device>::hamilt2density(const int istep, const int iter, cons
402402

403403
drho = p_chgmix->get_drho(pelec->charge, PARAM.inp.nelec);
404404
hsolver_error = 0.0;
405-
if (iter == 1)
405+
if (iter == 1 && PARAM.inp.calculation != "nscf")
406406
{
407407
hsolver_error
408408
= hsolver::cal_hsolve_error(PARAM.inp.basis_type, PARAM.inp.esolver_type, diag_ethr, PARAM.inp.nelec);
@@ -628,7 +628,7 @@ void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)
628628

629629
// If drho < hsolver_error in the first iter or drho < scf_thr, we
630630
// do not change rho.
631-
if (drho < hsolver_error || this->conv_esolver)
631+
if (drho < hsolver_error || this->conv_esolver || PARAM.inp.calculation == "nscf")
632632
{
633633
if (drho < hsolver_error)
634634
{

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "module_base/formatter.h"
44
#include "module_base/global_variable.h"
55
#include "module_base/tool_title.h"
6+
#include "module_elecstate/module_dm/cal_dm_psi.h"
67
#include "module_io/berryphase.h"
78
#include "module_io/cube_io.h"
89
#include "module_io/dos_nao.h"
@@ -51,7 +52,7 @@
5152
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
5253
#include "module_hsolver/hsolver_lcao.h"
5354
// function used by deepks
54-
#include "module_elecstate/cal_dm.h"
55+
// #include "module_elecstate/cal_dm.h"
5556
//---------------------------------------------------
5657

5758
#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h"
@@ -590,6 +591,14 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
590591
// and the ncalculate the charge density on grid.
591592

592593
this->pelec->skip_weights = true;
594+
this->pelec->calculate_weights();
595+
if (!PARAM.inp.dm_to_rho)
596+
{
597+
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec);
598+
_pelec->calEBand();
599+
elecstate::cal_dm_psi(_pelec->DM->get_paraV_pointer(), _pelec->wg, *this->psi, *(_pelec->DM));
600+
_pelec->DM->cal_DMR();
601+
}
593602
this->pelec->psiToRho(*this->psi);
594603
this->pelec->skip_weights = false;
595604

@@ -718,9 +727,9 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density_single(int istep, int iter, double
718727
// reset energy
719728
this->pelec->f_en.eband = 0.0;
720729
this->pelec->f_en.demet = 0.0;
721-
730+
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
722731
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
723-
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, false);
732+
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, skip_charge);
724733

725734
// 5) what's the exd used for?
726735
#ifdef __EXX

source/module_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density_single(const int istep, const int ite
165165
this->pelec->f_en.demet = 0.0;
166166
if (this->psi != nullptr)
167167
{
168+
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
168169
hsolver::HSolverLCAO<std::complex<double>> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver);
169-
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, false);
170+
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, skip_charge);
170171
}
171172
}
172173

source/module_esolver/esolver_ks_lcaopw.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ namespace ModuleESolver
126126
hsolver::DiagoIterAssist<T>::SCF_ITER = iter;
127127
hsolver::DiagoIterAssist<T>::PW_DIAG_THR = ethr;
128128
hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
129+
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
129130

130131
// It is not a good choice to overload another solve function here, this will spoil the concept of
131132
// multiple inheritance and polymorphism. But for now, we just do it in this way.
@@ -138,7 +139,7 @@ namespace ModuleESolver
138139
}
139140

140141
hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
141-
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], false);
142+
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge);
142143

143144
// add exx
144145
#ifdef __EXX

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
354354
hsolver::DiagoIterAssist<T, Device>::SCF_ITER = iter;
355355
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
356356
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
357+
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
357358

358359
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc,
359360
&this->wf,
@@ -375,7 +376,7 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(const int istep, const int
375376
this->pelec->ekb.c,
376377
GlobalV::RANK_IN_POOL,
377378
GlobalV::NPROC_IN_POOL,
378-
false);
379+
skip_charge);
379380

380381
this->init_psi = true;
381382

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
174174
this->pelec->f_en.demet = 0.0;
175175
// choose if psi should be diag in subspace
176176
// be careful that istep start from 0 and iter start from 1
177-
if (istep == 0 && iter == 1)
177+
if (istep == 0 && iter == 1 || PARAM.inp.calculation == "nscf")
178178
{
179179
hsolver::DiagoIterAssist<T, Device>::need_subspace = false;
180180
}
@@ -183,8 +183,8 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
183183
hsolver::DiagoIterAssist<T, Device>::need_subspace = true;
184184
}
185185

186+
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
186187
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
187-
188188
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
189189

190190
// hsolver only exists in this function
@@ -206,7 +206,15 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
206206
hsolver::DiagoIterAssist<T, Device>::need_subspace,
207207
this->init_psi);
208208

209-
hsolver_pw_sdft_obj.solve(this->p_hamilt, this->kspw_psi[0], this->psi[0], this->pelec, this->pw_wfc, this->stowf, istep, iter, false);
209+
hsolver_pw_sdft_obj.solve(this->p_hamilt,
210+
this->kspw_psi[0],
211+
this->psi[0],
212+
this->pelec,
213+
this->pw_wfc,
214+
this->stowf,
215+
istep,
216+
iter,
217+
skip_charge);
210218
this->init_psi = true;
211219

212220
// set_diagethr need it

0 commit comments

Comments
 (0)