Skip to content

Commit 3cb80dc

Browse files
committed
Refactor: finished HSolverLCAO class, gamma_only and multi-k lines
1 parent 9d9580e commit 3cb80dc

File tree

10 files changed

+129
-44
lines changed

10 files changed

+129
-44
lines changed

source/module_elecstate/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ add_library(
44
elecstate.cpp
55
elecstate_pw.cpp
66
elecstate_lcao.cpp
7+
dm2d_to_grid.cpp
78
)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "src_lcao/local_orbital_charge.h"
2+
#include "module_psi/psi.h"
3+
#include "module_base/timer.h"
4+
5+
/// transformation from 2d block density matrix to grid points for one k point
6+
void Local_Orbital_Charge::dm2dToGrid(const psi::Psi<double>& dm2d, double** dm_grid)
7+
{
8+
ModuleBase::timer::tick("Local_Orbital_Charge","dm_2dTOgrid");
9+
10+
#ifdef __MPI
11+
// put data from dm_gamma[ik] to sender index
12+
int nNONZERO=0;
13+
for(int i=0; i<this->sender_size; ++i)
14+
{
15+
const int idx=this->sender_2D_index[i];
16+
const int icol=idx%GlobalV::NLOCAL;
17+
const int irow=(idx-icol)/GlobalV::NLOCAL;
18+
// sender_buffer[i]=wfc_dm_2d.dm_gamma(ik, irow, icol);
19+
this->sender_buffer[i]=dm2d(icol,irow); // sender_buffer is clomun major,
20+
// so the row and column index should be switched
21+
if(this->sender_buffer[i]!=0) ++nNONZERO;
22+
}
23+
24+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"number of non-zero elements in sender_buffer",nNONZERO);
25+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"sender_size",this->sender_size);
26+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"last sender_buffer",this->sender_buffer[this->sender_size-1]);
27+
28+
// transform data via MPI_Alltoallv
29+
MPI_Alltoallv(this->sender_buffer, this->sender_size_process, this->sender_displacement_process, MPI_DOUBLE,
30+
this->receiver_buffer, this->receiver_size_process, this->receiver_displacement_process, MPI_DOUBLE, this->ParaV->comm_2D);
31+
32+
// put data from receiver buffer to dm_grid[ik]
33+
nNONZERO=0;
34+
for(int i=0; i<this->receiver_size; ++i)
35+
{
36+
const int idx=this->receiver_local_index[i];
37+
const int icol=idx%this->lgd_now;
38+
const int irow=(idx-icol)/this->lgd_now;
39+
dm_grid[irow][icol]=this->receiver_buffer[i];
40+
if(this->receiver_buffer[i]!=0) ++nNONZERO;
41+
}
42+
43+
44+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"number of non-zero elements in receiver_buffer",nNONZERO);
45+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"receiver_size",this->receiver_size);
46+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"last receiver_buffer",receiver_buffer[this->receiver_size-1]);
47+
#else
48+
for(int irow=0;irow<dm2d.get_nbasis();++irow)
49+
{
50+
for(int icol=0;icol<dm2d.get_nbands();++icol)
51+
{
52+
dm_grid[irow][icol] = dm2d(icol, irow);
53+
}
54+
}
55+
56+
#endif
57+
58+
ModuleBase::timer::tick("Local_Orbital_Charge","dm_2dTOgrid");
59+
return;
60+
}

source/module_elecstate/elecstate_lcao.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
namespace elecstate
77
{
8+
int ElecStateLCAO::out_wfc_lcao = 0;
89

910
//for gamma_only(double case) and multi-k(complex<double> case)
1011
template<typename T> void ElecStateLCAO::cal_dm(const ModuleBase::matrix& wg,
@@ -53,21 +54,37 @@ template<typename T> void ElecStateLCAO::cal_dm(const ModuleBase::matrix& wg,
5354
return;
5455
}
5556

57+
// multi-k case
5658
void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
5759
{
5860
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
5961
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
6062

63+
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
64+
65+
//this part for calculating dm_k in 2d-block format, not used for charge now
6166
psi::Psi<std::complex<double>> dm_k_2d(psi.get_nk(), psi.get_nbasis(), psi.get_nbasis());
6267

63-
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
64-
//this->loc->cal_dk_k(GlobalC::GridT);
6568
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx"
6669
|| GlobalV::KS_SOLVER == "lapack") // Peize Lin test 2019-05-15
6770
{
6871
this->cal_dm(this->wg, psi, dm_k_2d);
6972
}
7073

74+
//this part for steps:
75+
//1. psi_k transform from 2d-block to grid format
76+
//2. psi_k_grid -> DM_R
77+
//3. DM_R -> rho(r)
78+
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx")
79+
{
80+
for(int ik = 0; ik<psi.get_nk(); ik++)
81+
{
82+
psi.fix_k(ik);
83+
this->lowf->wfc_2d_to_grid(ElecStateLCAO::out_wfc_lcao, psi.get_pointer(), this->lowf->wfc_k_grid[ik], ik);
84+
}
85+
}
86+
87+
this->loc->cal_dk_k(GlobalC::GridT);
7188
for (int is = 0; is < GlobalV::NSPIN; is++)
7289
{
7390
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10
@@ -78,7 +95,7 @@ void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
7895
//------------------------------------------------------------
7996

8097
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
81-
//uhm.GK.cal_rho_k(this->loc->DM_R);
98+
this->uhm->GK.cal_rho_k(this->loc->DM_R);
8299

83100
this->charge->renormalize_rho();
84101

@@ -91,12 +108,12 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
91108
{
92109
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
93110
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
111+
112+
this->calculate_weights();
94113

95114
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx"
96115
|| GlobalV::KS_SOLVER == "lapack")
97116
{
98-
// LiuXh modify 2021-09-06, clear memory, cal_dk_gamma() not used for genelpa solver.
99-
// density matrix has already been calculated.
100117
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
101118

102119
psi::Psi<double> dm_gamma_2d(psi.get_nk(), psi.get_nbasis(), psi.get_nbasis());
@@ -105,7 +122,17 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
105122

106123
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
107124

108-
//this->loc->cal_dk_gamma_from_2D(); // transform dm_gamma[is].c to this->loc->DM[is]
125+
for(int ik = 0; ik< psi.get_nk(); ++ik)
126+
{
127+
// for gamma_only case, no convertion occured, just for print.
128+
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx")
129+
{
130+
psi.fix_k(ik);
131+
double** wfc_grid = nullptr; // output but not do "2d-to-grid" conversion
132+
this->lowf->wfc_2d_to_grid(ElecStateLCAO::out_wfc_lcao, psi.get_pointer(), wfc_grid);
133+
}
134+
this->loc->dm2dToGrid(dm_gamma_2d, this->loc->DM[ik]); // transform dm_gamma[is].c to this->loc->DM[is]
135+
}
109136
}
110137

111138
for (int is = 0; is < GlobalV::NSPIN; is++)
@@ -117,7 +144,7 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
117144
// calculate the charge density on real space grid.
118145
//------------------------------------------------------------
119146
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
120-
//uhm.GG.cal_rho(this->loc->DM);
147+
this->uhm->GG.cal_rho(this->loc->DM);
121148

122149
this->charge->renormalize_rho();
123150

source/module_elecstate/elecstate_lcao.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,27 @@
33

44
#include "elecstate.h"
55
#include "src_lcao/local_orbital_charge.h"
6+
#include "src_lcao/local_orbital_wfc.h"
7+
#include "src_lcao/LCAO_hamilt.h"
68

79
namespace elecstate
810
{
911

1012
class ElecStateLCAO : public ElecState
1113
{
1214
public:
13-
ElecStateLCAO(Charge* chg_in, int nks_in, int nbands_in){init(chg_in, nks_in, nbands_in);}
15+
ElecStateLCAO(Charge* chg_in,
16+
int nks_in,
17+
int nbands_in,
18+
Local_Orbital_Charge* loc_in,
19+
LCAO_Hamilt *uhm_in,
20+
Local_Orbital_wfc* lowf_in)
21+
{
22+
init(chg_in, nks_in, nbands_in);
23+
this->loc = loc_in;
24+
this->uhm = uhm_in;
25+
this->lowf = lowf_in;
26+
}
1427
//void init(Charge* chg_in):charge(chg_in){} override;
1528

1629
// interface for HSolver to calculate rho from Psi
@@ -22,22 +35,26 @@ class ElecStateLCAO : public ElecState
2235
// update charge density for next scf step
2336
//void getNewRho() override;
2437

38+
static int out_wfc_lcao;
39+
2540
private:
2641

2742
// calculate electronic charge density on grid points or density matrix in real space
2843
// the consequence charge density rho saved into rho_out, preparing for charge mixing.
29-
void updateRhoK(const psi::Psi<std::complex<double>>& psi) ;//override;
44+
//void updateRhoK(const psi::Psi<std::complex<double>>& psi) ;//override;
3045
//sum over all pools for rho and ebands
31-
void parallelK();
46+
//void parallelK();
3247
// calcualte rho for each k
33-
void rhoBandK(const psi::Psi<std::complex<double>>& psi);
48+
//void rhoBandK(const psi::Psi<std::complex<double>>& psi);
3449

3550
template<typename T> void cal_dm(
3651
const ModuleBase::matrix& wg,
3752
const psi::Psi<T>& wfc,
3853
psi::Psi<T>& dm);
3954

4055
Local_Orbital_Charge* loc = nullptr;
56+
LCAO_Hamilt *uhm = nullptr;
57+
Local_Orbital_wfc* lowf = nullptr;
4158
};
4259

4360
} // namespace elecstate

source/module_hsolver/hsolver_lcao.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ void HSolverLCAO::solveTemplate(hamilt::Hamilt* pHamilt, psi::Psi<T>& psi, elecs
3434

3535
int HSolverLCAO::out_mat_hs = 0;
3636
int HSolverLCAO::out_mat_hsR = 0;
37-
int HSolverLCAO::out_wfc_lcao = 0;
3837

3938
void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, psi::Psi<std::complex<double>>& psi, elecstate::ElecState* pes)
4039
{
@@ -48,22 +47,11 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, psi::Psi<double>& psi, elecstat
4847
void HSolverLCAO::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi<std::complex<double>>& psi, double* eigenvalue)
4948
{
5049
pdiagh->diag(hm, psi, eigenvalue);
51-
if (this->method == "scalapack_gvx" || this->method == "genelpa")
52-
{
53-
int ik = psi.get_current_k();
54-
this->lowf->wfc_2d_to_grid(HSolverLCAO::out_wfc_lcao, psi.get_pointer(), this->lowf->wfc_k_grid[ik], ik);
55-
}
5650
}
5751

5852
void HSolverLCAO::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi<double>& psi, double* eigenvalue)
5953
{
6054
pdiagh->diag(hm, psi, eigenvalue);
61-
// for gamma_only case, no convertion occured, just for print.
62-
if (this->method == "scalapack_gvx" || this->method == "genelpa")
63-
{
64-
double** wfc_grid = nullptr; // output but not do "2d-to-grid" conversion
65-
this->lowf->wfc_2d_to_grid(HSolverLCAO::out_wfc_lcao, psi.get_pointer(), wfc_grid);
66-
}
6755
}
6856

6957
} // namespace hsolver

source/module_hsolver/hsolver_lcao.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define HSOLVERLCAO_H
33

44
#include "hsolver.h"
5-
#include "src_lcao/local_orbital_wfc.h"
65

76
namespace hsolver
87
{
@@ -11,13 +10,6 @@ class HSolverLCAO : public HSolver
1110
{
1211
public:
1312

14-
HSolverLCAO(
15-
Local_Orbital_wfc* lowf_in
16-
)
17-
{
18-
this->lowf = lowf_in;
19-
}
20-
2113
/*void init(
2214
const Basis* pbas
2315
//const Input &in,
@@ -35,7 +27,6 @@ class HSolverLCAO : public HSolver
3527
psi::Psi<double>& psi,
3628
elecstate::ElecState* pes) override;
3729

38-
static int out_wfc_lcao;
3930
static int out_mat_hs; // mohan add 2010-09-02
4031
static int out_mat_hsR; // LiuXh add 2019-07-16
4132

@@ -55,8 +46,6 @@ class HSolverLCAO : public HSolver
5546
psi::Psi<std::complex<double>>& psi,
5647
elecstate::ElecState* pes
5748
);*/
58-
59-
Local_Orbital_wfc* lowf = nullptr;
6049
};
6150

6251
}//namespace hsolver

source/src_lcao/local_orbital_charge.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
#include "src_lcao/record_adj.h"
1111
#include "src_lcao/local_orbital_wfc.h"
1212
#include "src_lcao/LCAO_hamilt.h"
13+
#include "module_psi/psi.h"
1314

1415
class Local_Orbital_Charge
1516
{
16-
1717
public:
1818

1919
Local_Orbital_Charge();
@@ -33,7 +33,9 @@ class Local_Orbital_Charge
3333
void gamma_file(const Grid_Technique& gt,
3434
Local_Orbital_wfc &lowf);
3535
void cal_dk_gamma_from_2D_pub(void);
36-
36+
//transformation from 2d block to grid, only gamma_only used it now
37+
//template<typename T>
38+
void dm2dToGrid(const psi::Psi<double>& dm2d, double** dm_grid);
3739

3840
//-----------------
3941
// in DM_k.cpp
@@ -87,6 +89,9 @@ class Local_Orbital_Charge
8789
//-----------------
8890
const Parallel_Orbitals* ParaV;
8991

92+
//temporary set it to public for ElecStateLCAO class, would be refactor later
93+
void cal_dk_k(const Grid_Technique &gt);
94+
9095
private:
9196

9297
// whether the DM array has been allocated
@@ -96,8 +101,6 @@ class Local_Orbital_Charge
96101

97102
void cal_dk_gamma(void);
98103

99-
void cal_dk_k(const Grid_Technique &gt);
100-
101104
// mohan add 2010-09-06
102105
int lgd_last;// sub-FFT-mesh orbitals number in previous step.
103106
int lgd_now;// sub-FFT-mesh orbitals number in this step.

source/src_lcao/local_orbital_wfc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ int Local_Orbital_wfc::localIndex(int globalindex, int nblk, int nprocs, int& my
147147
}
148148

149149
#ifdef __MPI
150-
void Local_Orbital_wfc::wfc_2d_to_grid(int out_wfc_lcao, double* wfc_2d, double** wfc_grid)
150+
void Local_Orbital_wfc::wfc_2d_to_grid(int out_wfc_lcao, const double* wfc_2d, double** wfc_grid)
151151
{
152152
ModuleBase::TITLE(" Local_Orbital_wfc", "wfc_2d_to_grid");
153153
ModuleBase::timer::tick(" Local_Orbital_wfc","wfc_2d_to_grid");
@@ -222,7 +222,7 @@ void Local_Orbital_wfc::wfc_2d_to_grid(int out_wfc_lcao, double* wfc_2d, double*
222222

223223
void Local_Orbital_wfc::wfc_2d_to_grid(
224224
int out_wfc_lcao,
225-
std::complex<double>* wfc_2d,
225+
const std::complex<double>* wfc_2d,
226226
std::complex<double>** wfc_grid,
227227
int ik)
228228
{

source/src_lcao/local_orbital_wfc.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class Local_Orbital_wfc
6666
// (in which the implementation should be put in header file )
6767
// because sub-function `write_lowf_complex`contains GlobalC declared in `global.h`
6868
// which will cause lots of "not defined" if included in a header file.
69-
void wfc_2d_to_grid(int out_wfc_lcao, double* wfc_2d, double** wfc_grid);
70-
void wfc_2d_to_grid(int out_wfc_lcao, std::complex<double>* wfc_2d, std::complex<double>** wfc_grid, int ik);
69+
void wfc_2d_to_grid(int out_wfc_lcao, const double* wfc_2d, double** wfc_grid);
70+
void wfc_2d_to_grid(int out_wfc_lcao, const std::complex<double>* wfc_2d, std::complex<double>** wfc_grid, int ik);
7171
#endif
7272

7373
private:

source/src_pdiag/test/diago_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ Charge::~Charge(){};
3838
namespace GlobalC {Charge_Broyden CHR;};
3939
Local_Orbital_wfc::Local_Orbital_wfc(){};
4040
Local_Orbital_wfc::~Local_Orbital_wfc(){};
41-
void Local_Orbital_wfc::wfc_2d_to_grid(int out_wfc_lcao, double *wfc_2d, double **wfc_grid){};
41+
void Local_Orbital_wfc::wfc_2d_to_grid(int out_wfc_lcao, const double *wfc_2d, double **wfc_grid){};
4242
void Local_Orbital_wfc::wfc_2d_to_grid(int out_wfc_lcao,
43-
std::complex<double> *wfc_2d,
43+
const std::complex<double> *wfc_2d,
4444
std::complex<double> **wfc_grid,
4545
int ik){};
4646
Occupy::Occupy(){};

0 commit comments

Comments
 (0)