Skip to content

Commit d89f9a3

Browse files
authored
Merge branch 'deepmodeling:develop' into TDDFT_GPU_phase_1
2 parents e3c493d + 806de4a commit d89f9a3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+659
-729
lines changed

source/module_cell/klist.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ void K_Vectors::set(const UnitCell& ucell,
148148
// It's very important in parallel case,
149149
// firstly do the mpi_k() and then
150150
// do set_kup_and_kdw()
151-
GlobalC::Pkpoints.kinfo(nkstot,
152-
GlobalV::KPAR,
153-
GlobalV::MY_POOL,
154-
GlobalV::RANK_IN_POOL,
155-
GlobalV::NPROC,
156-
nspin_in); // assign k points to several process pools
151+
this->para_k.kinfo(nkstot,
152+
GlobalV::KPAR,
153+
GlobalV::MY_POOL,
154+
GlobalV::RANK_IN_POOL,
155+
GlobalV::NPROC,
156+
nspin_in); // assign k points to several process pools
157157
#ifdef __MPI
158158
// distribute K point data to the corresponding process
159159
this->mpi_k(); // 2008-4-29
@@ -1163,7 +1163,7 @@ void K_Vectors::mpi_k()
11631163

11641164
Parallel_Common::bcast_double(koffset, 3);
11651165

1166-
this->nks = GlobalC::Pkpoints.nks_pool[GlobalV::MY_POOL];
1166+
this->nks = this->para_k.nks_pool[GlobalV::MY_POOL];
11671167

11681168
GlobalV::ofs_running << std::endl;
11691169
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "k-point number in this process", nks);
@@ -1217,7 +1217,7 @@ void K_Vectors::mpi_k()
12171217
for (int i = 0; i < nks; i++)
12181218
{
12191219
// 3 is because each k point has three value:kx, ky, kz
1220-
k_index = i + GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL];
1220+
k_index = i + this->para_k.startk_pool[GlobalV::MY_POOL];
12211221
kvec_c[i].x = kvec_c_aux[k_index * 3];
12221222
kvec_c[i].y = kvec_c_aux[k_index * 3 + 1];
12231223
kvec_c[i].z = kvec_c_aux[k_index * 3 + 2];

source/module_cell/klist.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "module_base/global_variable.h"
66
#include "module_base/matrix3.h"
77
#include "module_cell/unitcell.h"
8-
8+
#include "parallel_kpoints.h"
99
#include <vector>
1010

1111
class K_Vectors
@@ -31,6 +31,9 @@ class K_Vectors
3131
K_Vectors& operator=(const K_Vectors&) = default;
3232
K_Vectors& operator=(K_Vectors&& rhs) = default;
3333

34+
Parallel_Kpoints para_k; ///< parallel for kpoints
35+
36+
3437
/**
3538
* @brief Set up the k-points for the system.
3639
*

source/module_cell/parallel_kpoints.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,6 @@
33
#include "module_base/parallel_common.h"
44
#include "module_base/parallel_global.h"
55

6-
Parallel_Kpoints::Parallel_Kpoints()
7-
{
8-
}
9-
10-
Parallel_Kpoints::~Parallel_Kpoints()
11-
{
12-
}
13-
146
// the kpoints here are reduced after symmetry applied.
157
void Parallel_Kpoints::kinfo(int& nkstot_in,
168
const int& kpar_in,
@@ -227,7 +219,7 @@ void Parallel_Kpoints::pool_collection(double* value_re,
227219
return;
228220
}
229221

230-
void Parallel_Kpoints::pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik)
222+
void Parallel_Kpoints::pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik) const
231223
{
232224
const int dim2 = w.getBound2();
233225
const int dim3 = w.getBound3();
@@ -237,7 +229,7 @@ void Parallel_Kpoints::pool_collection(std::complex<double>* value, const Module
237229
}
238230

239231
template <class T, class V>
240-
void Parallel_Kpoints::pool_collection_aux(T* value, const V& w, const int& dim, const int& ik)
232+
void Parallel_Kpoints::pool_collection_aux(T* value, const V& w, const int& dim, const int& ik) const
241233
{
242234
#ifdef __MPI
243235
const int ik_now = ik - this->startk_pool[this->my_pool];

source/module_cell/parallel_kpoints.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
class Parallel_Kpoints
1010
{
1111
public:
12-
Parallel_Kpoints();
13-
~Parallel_Kpoints();
12+
Parallel_Kpoints(){};
13+
~Parallel_Kpoints(){};
1414

1515
void kinfo(int& nkstot_in,
1616
const int& kpar_in,
@@ -28,9 +28,9 @@ class Parallel_Kpoints
2828
const ModuleBase::realArray& a,
2929
const ModuleBase::realArray& b,
3030
const int& ik);
31-
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik);
31+
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik) const;
3232
template <class T, class V>
33-
void pool_collection_aux(T* value, const V& w, const int& dim, const int& ik);
33+
void pool_collection_aux(T* value, const V& w, const int& dim, const int& ik) const;
3434
#ifdef __MPI
3535
/**
3636
* @brief gather kpoints from all processors
@@ -46,8 +46,8 @@ class Parallel_Kpoints
4646
// int* nproc_pool = nullptr; it is not used
4747

4848
// inforamation about kpoints, dim: KPAR
49-
std::vector<int> nks_pool; // number of k-points in each pool
50-
std::vector<int> startk_pool; // the first k-point in each pool
49+
std::vector<int> nks_pool; // number of k-points in each pool, here use k-points without spin
50+
std::vector<int> startk_pool; // the first k-point in each pool, here use k-points without spin
5151

5252
// information about which pool each k-point belongs to,
5353
std::vector<int> whichpool; // whichpool[k] : the pool which k belongs to, dim: nkstot_np

source/module_elecstate/elecstate.h

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#ifndef ELECSTATE_H
22
#define ELECSTATE_H
3-
#include "module_parameter/parameter.h"
4-
53
#include "fp_energy.h"
64
#include "module_cell/klist.h"
75
#include "module_elecstate/module_charge/charge.h"
6+
#include "module_parameter/parameter.h"
87
#include "module_psi/psi.h"
98
#include "potentials/potential_new.h"
109

@@ -14,10 +13,10 @@ namespace elecstate
1413
class ElecState
1514
{
1615
public:
17-
ElecState(){}
18-
ElecState(Charge* charge_in,
19-
ModulePW::PW_Basis* rhopw_in,
20-
ModulePW::PW_Basis_Big* bigpw_in)
16+
ElecState()
17+
{
18+
}
19+
ElecState(Charge* charge_in, ModulePW::PW_Basis* rhopw_in, ModulePW::PW_Basis_Big* bigpw_in)
2120
{
2221
this->charge = charge_in;
2322
this->charge->set_rhopw(rhopw_in);
@@ -26,20 +25,20 @@ class ElecState
2625
}
2726
virtual ~ElecState()
2827
{
29-
if(this->pot != nullptr)
28+
if (this->pot != nullptr)
3029
{
3130
delete this->pot;
3231
this->pot = nullptr;
3332
}
3433
}
35-
void init_ks(Charge *chg_in, // pointer for class Charge
36-
const K_Vectors *klist_in,
37-
int nk_in, // number of k points
38-
ModulePW::PW_Basis* rhopw_in,
39-
const ModulePW::PW_Basis_Big* bigpw_in);
34+
void init_ks(Charge* chg_in, // pointer for class Charge
35+
const K_Vectors* klist_in,
36+
int nk_in, // number of k points
37+
ModulePW::PW_Basis* rhopw_in,
38+
const ModulePW::PW_Basis_Big* bigpw_in);
4039

4140
// return current electronic density rho, as a input for constructing Hamiltonian
42-
virtual const double *getRho(int spin) const;
41+
virtual const double* getRho(int spin) const;
4342

4443
// calculate electronic charge density on grid points or density matrix in real space
4544
// the consequence charge density rho saved into rho_out, preparing for charge mixing.
@@ -78,17 +77,14 @@ class ElecState
7877

7978
// use occupied weights from INPUT and skip calculate_weights
8079
// mohan updated on 2024-06-08
81-
void fixed_weights(
82-
const std::vector<double>& ocp_kb,
83-
const int &nbands,
84-
const double &nelec);
80+
void fixed_weights(const std::vector<double>& ocp_kb, const int& nbands, const double& nelec);
8581

86-
// if nupdown is not 0(TWO_EFERMI case),
87-
// nelec_spin will be fixed and weights will be constrained
82+
// if nupdown is not 0(TWO_EFERMI case),
83+
// nelec_spin will be fixed and weights will be constrained
8884
void init_nelec_spin();
89-
//used to record number of electrons per spin index
90-
//for NSPIN=2, it will record number of spin up and number of spin down
91-
//for NSPIN=4, it will record total number, magnetization for x, y, z direction
85+
// used to record number of electrons per spin index
86+
// for NSPIN=2, it will record number of spin up and number of spin down
87+
// for NSPIN=4, it will record total number, magnetization for x, y, z direction
9288
std::vector<double> nelec_spin;
9389

9490
virtual void print_psi(const psi::Psi<double>& psi_in, const int istep = -1)
@@ -102,7 +98,7 @@ class ElecState
10298

10399
/**
104100
* @brief Init rho_core, init rho, renormalize rho, init pot
105-
*
101+
*
106102
* @param istep i-th step
107103
* @param ucell unit cell
108104
* @param strucfac structure factor
@@ -142,28 +138,24 @@ class ElecState
142138
void set_exx(const std::complex<double>& Eexx);
143139
#endif //__LCAO
144140
#endif //__EXX
145-
141+
146142
double get_hartree_energy();
147143
double get_etot_efield();
148144
double get_etot_gatefield();
149145

150146
double get_solvent_model_Ael();
151147
double get_solvent_model_Acav();
152148

153-
virtual double get_spin_constrain_energy() {
149+
virtual double get_spin_constrain_energy()
150+
{
154151
return 0.0;
155152
}
156153

157154
double get_dftu_energy();
158155
double get_local_pp_energy();
159156

160-
#ifdef __DEEPKS
161-
double get_deepks_E_delta();
162-
double get_deepks_E_delta_band();
163-
#endif
164-
165-
fenergy f_en; ///< energies contribute to the total free energy
166-
efermi eferm; ///< fermi energies
157+
fenergy f_en; ///< energies contribute to the total free energy
158+
efermi eferm; ///< fermi energies
167159

168160
// below defines the bandgap:
169161

source/module_elecstate/elecstate_energy.cpp

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
#include <cmath>
2-
31
#include "elecstate.h"
42
#include "elecstate_getters.h"
53
#include "module_base/global_variable.h"
64
#include "module_base/parallel_reduce.h"
75
#include "module_parameter/parameter.h"
6+
7+
#include <cmath>
88
#ifdef USE_PAW
99
#include "module_hamilt_general/module_xc/xc_functional.h"
1010
#include "module_hamilt_pw/hamilt_pwdft/global.h"
@@ -103,10 +103,9 @@ double ElecState::cal_delta_eband(const UnitCell& ucell) const
103103
const double* v_eff = this->pot->get_effective_v(0);
104104
const double* v_fixed = this->pot->get_fixed_v();
105105
const double* v_ofk = nullptr;
106-
const bool v_ofk_flag =(get_xc_func_type() == 3
107-
|| get_xc_func_type() == 5);
106+
const bool v_ofk_flag = (get_xc_func_type() == 3 || get_xc_func_type() == 5);
108107
#ifdef USE_PAW
109-
if(PARAM.inp.use_paw)
108+
if (PARAM.inp.use_paw)
110109
{
111110
ModuleBase::matrix v_xc;
112111
const std::tuple<double, double, ModuleBase::matrix> etxc_vtxc_v
@@ -115,19 +114,19 @@ double ElecState::cal_delta_eband(const UnitCell& ucell) const
115114

116115
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
117116
{
118-
deband_aux -= this->charge->rho[0][ir] * v_xc(0,ir);
117+
deband_aux -= this->charge->rho[0][ir] * v_xc(0, ir);
119118
}
120119
if (PARAM.inp.nspin == 2)
121120
{
122121
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
123122
{
124-
deband_aux -= this->charge->rho[1][ir] * v_xc(1,ir);
123+
deband_aux -= this->charge->rho[1][ir] * v_xc(1, ir);
125124
}
126125
}
127126
}
128127
#endif
129128

130-
if(!PARAM.inp.use_paw)
129+
if (!PARAM.inp.use_paw)
131130
{
132131
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
133132
{
@@ -137,16 +136,16 @@ double ElecState::cal_delta_eband(const UnitCell& ucell) const
137136
{
138137
v_ofk = this->pot->get_effective_vofk(0);
139138
// cause in the get_effective_vofk, the func will return nullptr
140-
if(v_ofk==nullptr && this->charge->rhopw->nrxx>0)
139+
if (v_ofk == nullptr && this->charge->rhopw->nrxx > 0)
141140
{
142-
ModuleBase::WARNING_QUIT("ElecState::cal_delta_eband","v_ofk is nullptr");
141+
ModuleBase::WARNING_QUIT("ElecState::cal_delta_eband", "v_ofk is nullptr");
143142
}
144143
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
145144
{
146145
deband_aux -= this->charge->kin_r[0][ir] * v_ofk[ir];
147146
}
148147
}
149-
148+
150149
if (PARAM.inp.nspin == 2)
151150
{
152151
v_eff = this->pot->get_effective_v(1);
@@ -157,9 +156,9 @@ double ElecState::cal_delta_eband(const UnitCell& ucell) const
157156
if (v_ofk_flag)
158157
{
159158
v_ofk = this->pot->get_effective_vofk(1);
160-
if(v_ofk==nullptr && this->charge->rhopw->nrxx>0)
159+
if (v_ofk == nullptr && this->charge->rhopw->nrxx > 0)
161160
{
162-
ModuleBase::WARNING_QUIT("ElecState::cal_delta_eband","v_ofk is nullptr");
161+
ModuleBase::WARNING_QUIT("ElecState::cal_delta_eband", "v_ofk is nullptr");
163162
}
164163
for (int ir = 0; ir < this->charge->rhopw->nrxx; ir++)
165164
{
@@ -219,7 +218,7 @@ double ElecState::cal_delta_escf() const
219218
if (get_xc_func_type() == 3 || get_xc_func_type() == 5)
220219
{
221220
// cause in the get_effective_vofk, the func will return nullptr
222-
assert(v_ofk!=nullptr);
221+
assert(v_ofk != nullptr);
223222
descf -= (this->charge->kin_r[0][ir] - this->charge->kin_r_save[0][ir]) * v_ofk[ir];
224223
}
225224
}
@@ -278,7 +277,7 @@ void ElecState::cal_converged()
278277
/**
279278
* @brief calculate energies
280279
*
281-
* @param type: 1 means Harris-Foulkes functinoal;
280+
* @param type: 1 means Harris-Foulkes functinoal;
282281
* @param type: 2 means Kohn-Sham functional;
283282
*/
284283
void ElecState::cal_energies(const int type)
@@ -292,7 +291,7 @@ void ElecState::cal_energies(const int type)
292291
//! energy from gate-field
293292
this->f_en.gatefield = get_etot_gatefield();
294293

295-
//! energy from implicit solvation model
294+
//! energy from implicit solvation model
296295
if (PARAM.inp.imp_sol)
297296
{
298297
this->f_en.esol_el = get_solvent_model_Ael();
@@ -305,27 +304,19 @@ void ElecState::cal_energies(const int type)
305304
this->f_en.escon = get_spin_constrain_energy();
306305
}
307306

308-
// energy from DFT+U
307+
// energy from DFT+U
309308
if (PARAM.inp.dft_plus_u)
310309
{
311310
this->f_en.edftu = get_dftu_energy();
312311
}
313312

314313
this->f_en.e_local_pp = get_local_pp_energy();
315314

316-
#ifdef __DEEPKS
317-
// energy from deepks
318-
if (PARAM.inp.deepks_scf)
319-
{
320-
this->f_en.edeepks_scf = get_deepks_E_delta() - get_deepks_E_delta_band();
321-
}
322-
#endif
323-
324315
if (type == 1) // Harris-Foulkes functional
325316
{
326317
this->f_en.calculate_harris();
327318
}
328-
else if (type == 2)// Kohn-Sham functional
319+
else if (type == 2) // Kohn-Sham functional
329320
{
330321
this->f_en.calculate_etot();
331322
}

0 commit comments

Comments
 (0)