Skip to content

Commit e49f0b7

Browse files
authored
Merge pull request #1002 from deepmodeling/HSolver
Refactor: add HSolver module in LCAO
2 parents 51bdbe3 + 0b250ad commit e49f0b7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1151
-2097
lines changed

source/Makefile.Objects

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,7 @@ ORB_table_alpha.o\
136136
ORB_gen_tables.o\
137137
local_orbital_wfc.o\
138138
local_orbital_charge.o\
139-
ELEC_cbands_k.o\
140-
ELEC_cbands_gamma.o\
141139
ELEC_evolve.o\
142-
ELEC_scf.o\
143-
ELEC_nscf.o\
144140
LOOP_cell.o\
145141
LOOP_ions.o\
146142
run_md_lcao.o\

source/input_conv.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
#include "module_base/timer.h"
3030
#include "module_surchem/efield.h"
3131

32+
#include "module_elecstate/elecstate_lcao.h"
33+
#include "module_hsolver/hsolver_lcao.h"
34+
3235
void Input_Conv::Convert(void)
3336
{
3437
ModuleBase::TITLE("Input_Conv", "Convert");
@@ -436,9 +439,9 @@ void Input_Conv::Convert(void)
436439
GlobalC::en.out_proj_band = INPUT.out_proj_band;
437440
#ifdef __LCAO
438441
Local_Orbital_Charge::out_dm = INPUT.out_dm;
439-
Pdiag_Double::out_mat_hs = INPUT.out_mat_hs;
440-
Pdiag_Double::out_mat_hsR = INPUT.out_mat_hs2; // LiuXh add 2019-07-16
441-
Pdiag_Double::out_wfc_lcao = INPUT.out_wfc_lcao;
442+
hsolver::HSolverLCAO::out_mat_hs = INPUT.out_mat_hs;
443+
hsolver::HSolverLCAO::out_mat_hsR = INPUT.out_mat_hs2; // LiuXh add 2019-07-16
444+
elecstate::ElecStateLCAO::out_wfc_lcao = INPUT.out_wfc_lcao;
442445
#endif
443446

444447
GlobalC::en.dos_emin_ev = INPUT.dos_emin_ev;

source/module_elecstate/cal_dm.h

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#ifndef CAL_DM_H
2+
#define CAL_DM_H
3+
4+
#include "math_tools.h"
5+
#include "module_base/matrix.h"
6+
#include "module_base/complexmatrix.h"
7+
#include "src_lcao/local_orbital_charge.h"
8+
9+
namespace elecstate
10+
{
11+
12+
// for gamma_only(double case) and multi-k(complex<double> case)
13+
inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, const psi::Psi<double>& wfc, std::vector<ModuleBase::matrix>& dm)
14+
{
15+
ModuleBase::TITLE("elecstate", "cal_dm");
16+
17+
//dm.resize(wfc.get_nk(), ParaV->ncol, ParaV->nrow);
18+
const int nbands_local = wfc.get_nbands();
19+
const int nbasis_local = wfc.get_nbasis();
20+
21+
// dm = wfc.T * wg * wfc.conj()
22+
// dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
23+
for (int ik = 0; ik < wfc.get_nk(); ++ik)
24+
{
25+
wfc.fix_k(ik);
26+
//dm.fix_k(ik);
27+
dm[ik].create(ParaV->ncol, ParaV->nrow);
28+
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
29+
psi::Psi<double> wg_wfc(wfc, 1);
30+
31+
int ib_global = 0;
32+
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
33+
{
34+
while (ib_local != ParaV->trace_loc_col[ib_global])
35+
{
36+
++ib_global;
37+
if (ib_global >= wg.nc)
38+
{
39+
break;
40+
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check trace_loc_col!");
41+
}
42+
}
43+
if (ib_global >= wg.nc) continue;
44+
const double wg_local = wg(ik, ib_global);
45+
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
46+
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
47+
}
48+
49+
// C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
50+
#ifdef __MPI
51+
psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc);
52+
#else
53+
psiMulPsi(wg_wfc, wfc, dm[ik]);
54+
#endif
55+
}
56+
57+
return;
58+
}
59+
60+
inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, const psi::Psi<std::complex<double>>& wfc, std::vector<ModuleBase::ComplexMatrix>& dm)
61+
{
62+
ModuleBase::TITLE("elecstate", "cal_dm");
63+
64+
//dm.resize(wfc.get_nk(), ParaV->ncol, ParaV->nrow);
65+
const int nbands_local = wfc.get_nbands();
66+
const int nbasis_local = wfc.get_nbasis();
67+
68+
// dm = wfc.T * wg * wfc.conj()
69+
// dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
70+
for (int ik = 0; ik < wfc.get_nk(); ++ik)
71+
{
72+
wfc.fix_k(ik);
73+
//dm.fix_k(ik);
74+
dm[ik].create(ParaV->ncol, ParaV->nrow);
75+
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
76+
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr);
77+
const std::complex<double>* pwfc = wfc.get_pointer();
78+
std::complex<double>* pwg_wfc = wg_wfc.get_pointer();
79+
for(int i = 0;i<wg_wfc.size();++i)
80+
{
81+
pwg_wfc[i] = conj(pwfc[i]);
82+
}
83+
84+
int ib_global = 0;
85+
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
86+
{
87+
while (ib_local != ParaV->trace_loc_col[ib_global])
88+
{
89+
++ib_global;
90+
if (ib_global >= wg.nc)
91+
{
92+
break;
93+
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check trace_loc_col!");
94+
}
95+
}
96+
if (ib_global >= wg.nc) continue;
97+
const double wg_local = wg(ik, ib_global);
98+
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
99+
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
100+
}
101+
102+
// C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
103+
#ifdef __MPI
104+
psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc);
105+
#else
106+
psiMulPsi(wg_wfc, wfc, dm[ik]);
107+
#endif
108+
}
109+
110+
return;
111+
}
112+
113+
}//namespace elecstate
114+
115+
#endif

source/module_elecstate/elecstate.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ void ElecState::calculate_weights()
2020
{
2121
ModuleBase::TITLE("ElecState", "calculate_weights");
2222

23+
if (GlobalV::ocp == 1)
24+
{
25+
for (int ik = 0; ik < GlobalC::kv.nks; ik++)
26+
{
27+
for (int ib = 0; ib < GlobalV::NBANDS; ib++)
28+
{
29+
this->wg(ik, ib) = GlobalV::ocp_kb[ik * GlobalV::NBANDS + ib];
30+
}
31+
}
32+
return;
33+
}
34+
2335
// for test
2436
// std::cout << " gaussian_broadening = " << use_gaussian_broadening << std::endl;
2537
// std::cout << " tetrahedron_method = " << use_tetrahedron_method << std::endl;
@@ -191,7 +203,7 @@ void ElecState::calculate_weights()
191203

192204
void ElecState::calEBand()
193205
{
194-
ModuleBase::TITLE("ElecStatePW", "calEBand");
206+
ModuleBase::TITLE("ElecState", "calEBand");
195207
//calculate ebands using wg and ekb
196208
this->eband = 0.0;
197209
for (int ik = 0; ik < this->ekb.nr; ++ik)

source/module_elecstate/elecstate_lcao.cpp

Lines changed: 28 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,13 @@
11
#include "elecstate_lcao.h"
22

3-
#include "math_tools.h"
3+
#include "cal_dm.h"
44
#include "module_base/timer.h"
55
#include "module_gint/grid_technique.h"
66

77
namespace elecstate
88
{
99
int ElecStateLCAO::out_wfc_lcao = 0;
1010

11-
// for gamma_only(double case) and multi-k(complex<double> case)
12-
template <typename T> void ElecStateLCAO::cal_dm(const ModuleBase::matrix& wg, const psi::Psi<T>& wfc, psi::Psi<T>& dm)
13-
{
14-
ModuleBase::TITLE("ElecStateLCAO", "cal_dm");
15-
16-
dm.resize(wfc.get_nk(), this->loc->ParaV->ncol, this->loc->ParaV->nrow);
17-
const int nbands_local = wfc.get_nbands();
18-
const int nbasis_local = wfc.get_nbasis();
19-
20-
// dm = wfc.T * wg * wfc.conj()
21-
// dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
22-
for (int ik = 0; ik < wfc.get_nk(); ++ik)
23-
{
24-
wfc.fix_k(ik);
25-
dm.fix_k(ik);
26-
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
27-
psi::Psi<T> wg_wfc(wfc, 1);
28-
29-
int ib_global = 0;
30-
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
31-
{
32-
while (ib_local != this->loc->ParaV->trace_loc_col[ib_global])
33-
{
34-
++ib_global;
35-
if (ib_global >= wg.nc)
36-
{
37-
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check trace_loc_col!");
38-
}
39-
}
40-
const double wg_local = wg(ik, ib_global);
41-
T* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
42-
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
43-
}
44-
45-
// C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
46-
#ifdef __MPI
47-
psiMulPsiMpi(wg_wfc, wfc, dm, this->loc->ParaV->desc_wfc, this->loc->ParaV->desc);
48-
#else
49-
psiMulPsi(wg_wfc, wfc, dm);
50-
#endif
51-
}
52-
53-
return;
54-
}
55-
5611
// multi-k case
5712
void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
5813
{
@@ -65,12 +20,12 @@ void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
6520
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
6621

6722
// this part for calculating dm_k in 2d-block format, not used for charge now
68-
psi::Psi<std::complex<double>> dm_k_2d(psi.get_nk(), psi.get_nbasis(), psi.get_nbasis());
23+
// psi::Psi<std::complex<double>> dm_k_2d();
6924

7025
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx"
7126
|| GlobalV::KS_SOLVER == "lapack") // Peize Lin test 2019-05-15
7227
{
73-
this->cal_dm(this->wg, psi, dm_k_2d);
28+
cal_dm(this->loc->ParaV, this->wg, psi, this->loc->dm_k);
7429
}
7530

7631
// this part for steps:
@@ -82,11 +37,29 @@ void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
8237
for (int ik = 0; ik < psi.get_nk(); ik++)
8338
{
8439
psi.fix_k(ik);
85-
this->lowf->wfc_2d_to_grid(ElecStateLCAO::out_wfc_lcao, psi.get_pointer(), this->lowf->wfc_k_grid[ik], ik);
40+
this->lowf->wfc_2d_to_grid(ElecStateLCAO::out_wfc_lcao, psi.get_pointer(), this->lowf->wfc_k_grid[ik], ik, this->ekb, this->wg);
41+
//added by zhengdy-soc, rearrange the wfc_k_grid from [up,down,up,down...] to [up,up...down,down...],
42+
if(GlobalV::NSPIN==4)
43+
{
44+
int row = GlobalC::GridT.lgd;
45+
std::vector<std::complex<double>> tmp(row);
46+
for(int ib=0; ib<GlobalV::NBANDS; ib++)
47+
{
48+
for(int iw=0; iw<row / GlobalV::NPOL; iw++)
49+
{
50+
tmp[iw] = this->lowf->wfc_k_grid[ik][ib][iw * GlobalV::NPOL];
51+
tmp[iw + row / GlobalV::NPOL] = this->lowf->wfc_k_grid[ik][ib][iw * GlobalV::NPOL + 1];
52+
}
53+
for(int iw=0; iw<row; iw++)
54+
{
55+
this->lowf->wfc_k_grid[ik][ib][iw] = tmp[iw];
56+
}
57+
}
58+
}
8659
}
8760
}
8861

89-
this->loc->cal_dk_k(GlobalC::GridT);
62+
this->loc->cal_dk_k(GlobalC::GridT, this->wg);
9063
for (int is = 0; is < GlobalV::NSPIN; is++)
9164
{
9265
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10
@@ -119,9 +92,9 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
11992
{
12093
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
12194

122-
psi::Psi<double> dm_gamma_2d(psi.get_nk(), psi.get_nbasis(), psi.get_nbasis());
95+
//psi::Psi<double> dm_gamma_2d;
12396
// caution:wfc and dm
124-
this->cal_dm(this->wg, psi, dm_gamma_2d);
97+
cal_dm(this->loc->ParaV, this->wg, psi, this->loc->dm_gamma);
12598

12699
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
127100

@@ -132,9 +105,10 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
132105
{
133106
psi.fix_k(ik);
134107
double** wfc_grid = nullptr; // output but not do "2d-to-grid" conversion
135-
this->lowf->wfc_2d_to_grid(ElecStateLCAO::out_wfc_lcao, psi.get_pointer(), wfc_grid);
108+
this->lowf->wfc_2d_to_grid(ElecStateLCAO::out_wfc_lcao, psi.get_pointer(), wfc_grid, this->ekb, this->wg);
136109
}
137-
this->loc->dm2dToGrid(dm_gamma_2d, this->loc->DM[ik]); // transform dm_gamma[is].c to this->loc->DM[is]
110+
//this->loc->dm2dToGrid(this->loc->dm_gamma[ik], this->loc->DM[ik]); // transform dm_gamma[is].c to this->loc->DM[is]
111+
this->loc->cal_dk_gamma_from_2D_pub();
138112
}
139113
}
140114

source/module_elecstate/elecstate_lcao.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ class ElecStateLCAO : public ElecState
4848
// calcualte rho for each k
4949
// void rhoBandK(const psi::Psi<std::complex<double>>& psi);
5050

51-
template <typename T> void cal_dm(const ModuleBase::matrix& wg, const psi::Psi<T>& wfc, psi::Psi<T>& dm);
52-
5351
Local_Orbital_Charge* loc = nullptr;
5452
LCAO_Hamilt* uhm = nullptr;
5553
Local_Orbital_wfc* lowf = nullptr;

source/module_elecstate/elecstate_pw.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
#include "module_base/constants.h"
44
#include "src_parallel/parallel_reduce.h"
55
#include "src_pw/global.h"
6+
#include "module_base/timer.h"
67

78
namespace elecstate
89
{
910

1011
void ElecStatePW::psiToRho(const psi::Psi<std::complex<double>>& psi)
1112
{
13+
ModuleBase::TITLE("ElecStatePW", "psiToRho");
14+
ModuleBase::timer::tick("ElecStatePW", "psiToRho");
1215
this->calculate_weights();
1316

1417
this->calEBand();
@@ -28,6 +31,7 @@ void ElecStatePW::psiToRho(const psi::Psi<std::complex<double>>& psi)
2831
this->updateRhoK(psi);
2932
}
3033
this->parallelK();
34+
ModuleBase::timer::tick("ElecStatePW", "psiToRho");
3135
}
3236

3337
void ElecStatePW::updateRhoK(const psi::Psi<std::complex<double>>& psi)

0 commit comments

Comments
 (0)