Skip to content

Commit 8ef5b8f

Browse files
committed
Refactor: added cal_dm in ElecStateLCAO, psi_k_2d->dm_k_2d
1 parent 683b69f commit 8ef5b8f

File tree

11 files changed

+304
-13
lines changed

11 files changed

+304
-13
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_library(
22
elecstate
33
OBJECT
4-
elecstate_pw.cpp
54
elecstate.cpp
5+
elecstate_pw.cpp
6+
elecstate_lcao.cpp
67
)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#include "elecstate_lcao.h"
2+
#include "math_tools.h"
3+
#include "module_base/timer.h"
4+
#include "src_lcao/grid_technique.h"
5+
6+
namespace elecstate
7+
{
8+
9+
//for gamma_only(double case) and multi-k(complex<double> case)
10+
template<typename T> void ElecStateLCAO::cal_dm(const ModuleBase::matrix& wg,
11+
const psi::Psi<T>& wfc,
12+
psi::Psi<T>& dm)
13+
{
14+
ModuleBase::TITLE("ElecStateLCAO", "cal_dm");
15+
16+
dm.resize( wfc.get_nk(), this->loc->ParaV->ncol, this->loc->ParaV->nrow );
17+
const int nbands_local = wfc.get_nbands();
18+
const int nbasis_local = wfc.get_nbasis();
19+
20+
// dm = wfc.T * wg * wfc.conj()
21+
// dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
22+
for(int ik=0; ik<wfc.get_nk(); ++ik)
23+
{
24+
wfc.fix_k(ik);
25+
dm.fix_k(ik);
26+
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
27+
psi::Psi<T> wg_wfc( wfc, 1 );
28+
29+
int ib_global = 0;
30+
for(int ib_local=0; ib_local<nbands_local; ++ib_local)
31+
{
32+
while(ib_local != this->loc->ParaV->trace_loc_col[ib_global])
33+
{
34+
++ib_global;
35+
if(ib_global>=wg.nc)
36+
{
37+
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check trace_loc_col!");
38+
}
39+
}
40+
const double wg_local = wg(ik, ib_global);
41+
T* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
42+
BlasConnector::scal( nbasis_local, wg_local, wg_wfc_pointer, 1 );
43+
}
44+
45+
// C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
46+
#ifdef __MPI
47+
psiMulPsiMpi(wg_wfc, wfc, dm, this->loc->ParaV->desc_wfc, this->loc->ParaV->desc);
48+
#else
49+
psiMulPsi(wg_wfc, wfc, dm);
50+
#endif
51+
}
52+
53+
return;
54+
}
55+
56+
void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
57+
{
58+
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
59+
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
60+
61+
psi::Psi<std::complex<double>> dm_k_2d(psi.get_nk(), psi.get_nbasis(), psi.get_nbasis());
62+
63+
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");
64+
//this->loc->cal_dk_k(GlobalC::GridT);
65+
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx"
66+
|| GlobalV::KS_SOLVER == "lapack") // Peize Lin test 2019-05-15
67+
{
68+
this->cal_dm(this->wg, psi, dm_k_2d);
69+
}
70+
71+
for (int is = 0; is < GlobalV::NSPIN; is++)
72+
{
73+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10
74+
}
75+
76+
//------------------------------------------------------------
77+
// calculate the charge density on real space grid.
78+
//------------------------------------------------------------
79+
80+
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
81+
//uhm.GK.cal_rho_k(this->loc->DM_R);
82+
83+
this->charge->renormalize_rho();
84+
85+
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
86+
return;
87+
}
88+
89+
// Gamma_only case
90+
void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
91+
{
92+
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
93+
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
94+
95+
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "scalapack_gvx"
96+
|| GlobalV::KS_SOLVER == "lapack")
97+
{
98+
// LiuXh modify 2021-09-06, clear memory, cal_dk_gamma() not used for genelpa solver.
99+
// density matrix has already been calculated.
100+
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
101+
102+
psi::Psi<double> dm_gamma_2d(psi.get_nk(), psi.get_nbasis(), psi.get_nbasis());
103+
// caution:wfc and dm
104+
this->cal_dm(this->wg, psi, dm_gamma_2d);
105+
106+
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
107+
108+
//this->loc->cal_dk_gamma_from_2D(); // transform dm_gamma[is].c to this->loc->DM[is]
109+
}
110+
111+
for (int is = 0; is < GlobalV::NSPIN; is++)
112+
{
113+
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10
114+
}
115+
116+
//------------------------------------------------------------
117+
// calculate the charge density on real space grid.
118+
//------------------------------------------------------------
119+
ModuleBase::GlobalFunc::NOTE("Calculate the charge on real space grid!");
120+
//uhm.GG.cal_rho(this->loc->DM);
121+
122+
this->charge->renormalize_rho();
123+
124+
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");
125+
return;
126+
}
127+
128+
} // namespace elecstate
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#ifndef ELECSTATELCAO_H
2+
#define ELECSTATELCAO_H
3+
4+
#include "elecstate.h"
5+
#include "src_lcao/local_orbital_charge.h"
6+
7+
namespace elecstate
8+
{
9+
10+
class ElecStateLCAO : public ElecState
11+
{
12+
public:
13+
ElecStateLCAO(Charge* chg_in, int nks_in, int nbands_in){init(chg_in, nks_in, nbands_in);}
14+
//void init(Charge* chg_in):charge(chg_in){} override;
15+
16+
// interface for HSolver to calculate rho from Psi
17+
virtual void psiToRho(const psi::Psi<std::complex<double>> &psi) override;
18+
virtual void psiToRho(const psi::Psi<double> &psi) override;
19+
// return current electronic density rho, as a input for constructing Hamiltonian
20+
//const double* getRho(int spin) const override;
21+
22+
// update charge density for next scf step
23+
//void getNewRho() override;
24+
25+
private:
26+
27+
// calculate electronic charge density on grid points or density matrix in real space
28+
// the consequence charge density rho saved into rho_out, preparing for charge mixing.
29+
void updateRhoK(const psi::Psi<std::complex<double>>& psi) ;//override;
30+
//sum over all pools for rho and ebands
31+
void parallelK();
32+
// calcualte rho for each k
33+
void rhoBandK(const psi::Psi<std::complex<double>>& psi);
34+
35+
template<typename T> void cal_dm(
36+
const ModuleBase::matrix& wg,
37+
const psi::Psi<T>& wfc,
38+
psi::Psi<T>& dm);
39+
40+
Local_Orbital_Charge* loc = nullptr;
41+
};
42+
43+
} // namespace elecstate
44+
45+
#endif
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#include "module_psi/psi.h"
2+
#include "module_base/scalapack_connector.h"
3+
#include "module_base/blas_connector.h"
4+
5+
#ifdef __MPI
6+
void psiMulPsiMpi(
7+
const psi::Psi<double>& psi1,
8+
const psi::Psi<double>& psi2,
9+
psi::Psi<double>& dm_out,
10+
const int* desc_psi,
11+
const int* desc_dm)
12+
{
13+
const double one_float=1.0, zero_float=0.0;
14+
const int one_int=1;
15+
const char N_char='N', T_char='T';
16+
const int nlocal = desc_dm[2];
17+
const int nbands = desc_psi[3];
18+
pdgemm_(
19+
&N_char, &T_char,
20+
&nlocal, &nlocal, &nbands,
21+
&one_float,
22+
psi1.get_pointer(), &one_int, &one_int, desc_psi,
23+
psi2.get_pointer(), &one_int, &one_int, desc_psi,
24+
&zero_float,
25+
dm_out.get_pointer(), &one_int, &one_int, desc_dm);
26+
}
27+
28+
void psiMulPsiMpi(
29+
const psi::Psi<std::complex<double>>& psi1,
30+
const psi::Psi<std::complex<double>>& psi2,
31+
psi::Psi<std::complex<double>>& dm_out,
32+
const int* desc_psi,
33+
const int* desc_dm)
34+
{
35+
const complex<double> one_complex={1.0,0.0}, zero_complex={0.0,0.0};
36+
const int one_int=1;
37+
const char N_char='N', T_char='T';
38+
const int nlocal = desc_dm[2];
39+
const int nbands = desc_psi[3];
40+
pzgemm_(
41+
&N_char, &T_char,
42+
&nlocal, &nlocal, &nbands,
43+
(const double*)(&one_complex),
44+
psi1.get_pointer(), &one_int, &one_int, desc_psi,
45+
psi2.get_pointer(), &one_int, &one_int, desc_psi,
46+
(const double*)(&zero_complex),
47+
dm_out.get_pointer(), &one_int, &one_int, desc_dm);
48+
}
49+
50+
#else
51+
void psiMulPsi(psi::Psi<double>& psi1, psi::Psi<double>& psi2, psi::Psi<double>& dm_out)
52+
{
53+
const double one_float=1.0, zero_float=0.0;
54+
const int one_int=1;
55+
const char N_char='N', T_char='T';
56+
const int nlocal = psi1.get_nbasis();
57+
const int nbands = psi1.get_nbands();
58+
dgemm_(
59+
&N_char, &T_char,
60+
&nlocal, &nlocal, &nbands,
61+
&one_float,
62+
psi1.get_pointer(), &nlocal,
63+
psi2.get_pointer(), &nlocal,
64+
&zero_float,
65+
dm_out.get_pointer(), &nlocal);
66+
}
67+
68+
void psiMulPsi(psi::Psi<double>& psi1, psi::Psi<double>& psi2, psi::Psi<double>& dm_out)
69+
{
70+
const int one_int=1;
71+
const char N_char='N', T_char='T';
72+
const int nlocal = psi1.get_nbasis();
73+
const int nbands = psi1.get_nbands();
74+
const complex<double> one_complex={1.0,0.0}, zero_complex={0.0,0.0};
75+
zgemm_(
76+
&N_char, &T_char,
77+
&nlocal, &nlocal, &nbands,
78+
&one_complex,
79+
psi1.get_pointer(), &nlocal,
80+
psi2.get_pointer(), &nlocal,
81+
&zero_complex,
82+
dm_out.get_pointer(), &nlocal);
83+
}
84+
85+
#endif

source/module_hsolver/test/diago_mock.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class HPsi
7878
{
7979
PW_Basis* pbas;
8080
int* ngk = nullptr;
81-
psi::Psi<std::complex<double>> psitmp(ngk,1,nband,npw);
81+
psi::Psi<std::complex<double>> psitmp(1,nband,npw,ngk);
8282
for(int i=0;i<nband;i++)
8383
{
8484
for(int j=0;j<npw;j++) psitmp(0,i,j) = psimatrix(i,j);

source/module_psi/psi.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,27 @@ class Psi
2525
this->resize(pbasis_in->Klist->nks, GlobalV::NBANDS, pbasis_in->ngmw);
2626
}
2727
Psi(const int* ngk_in){this->ngk = ngk_in;}
28-
Psi(const int* ngk_in, int nk_in, int nbd_in, int nbs_in)
28+
Psi(int nk_in, int nbd_in, int nbs_in, const int* ngk_in=nullptr)
2929
{
3030
this->ngk = ngk_in;
3131
this->resize(nk_in, nbd_in, nbs_in);
3232
this->current_b = 0;
3333
this->current_k = 0;
3434
}
35-
Psi(const Psi& psi_in, const int& nk_in, const int& nband_in)
35+
Psi(const Psi& psi_in, const int& nk_in)
3636
{
37-
assert(nk_in<=psi_in.get_nk() && nband_in<=psi_in.get_nbands());
38-
this->resize(nk_in, nband_in, psi_in.get_nbasis());
37+
assert(nk_in<=psi_in.get_nk());
38+
this->resize(nk_in, psi_in.get_nbands(), psi_in.get_nbasis());
3939
//if size of k is 1, copy from Psi in current_k,
4040
//else copy from start of Psi
41+
const T* tmp = psi_in.get_pointer();
4142
if(nk_in==1) for(size_t index=0; index<this->size();++index)
4243
{
43-
psi[index] = psi_in.get_pointer()[index];
44+
psi[index] = tmp[index];
4445
//current_k for this Psi only keep the spin index same as the copied Psi
4546
this->current_k = psi_in.get_current_k();
4647
}
47-
else for(size_t index=0; index<this->size();++index) psi[index] = psi_in.get_pointer()[index];
48+
else for(size_t index=0; index<this->size();++index) psi[index] = tmp[index];
4849
}
4950
// initialize the wavefunction coefficient
5051
// only resize and construct function now is used

source/src_lcao/LCAO_nnr.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#endif
88

99
// This is for cell R dependent part.
10-
void Grid_Technique::cal_nnrg()
10+
void Grid_Technique::cal_nnrg(Parallel_Orbitals* pv)
1111
{
1212
ModuleBase::TITLE("LCAO_nnr","cal_nnrg");
1313

@@ -18,6 +18,7 @@ void Grid_Technique::cal_nnrg()
1818
delete[] nlocdimg;
1919
delete[] nlocstartg;
2020
delete[] nad;
21+
this->nnrg_index.resize(0);
2122

2223
this->nad = new int[GlobalC::ucell.nat];
2324
this->nlocdimg = new int[GlobalC::ucell.nat];
@@ -81,7 +82,12 @@ void Grid_Technique::cal_nnrg()
8182
// GlobalC::ORB.Phi[it].getRcut = 7.0000000000000008
8283
if(distance < rcut - 1.0e-15)
8384
{
84-
const int nelement = atom1->nw * atom2->nw;//modified by zhengdy-soc, no need to double
85+
//storing the indexed for nnrg
86+
const int mu = pv->trace_loc_row[iat];
87+
const int nu = pv->trace_loc_col[iat2];
88+
this->nnrg_index.push_back(gridIntegral::gridIndex{this->nnrg, mu, nu, GlobalC::GridD.getBox(ad), atom1->nw, atom2->nw});
89+
90+
const int nelement = atom1->nw * atom2->nw;
8591
this->nnrg += nelement;
8692
this->nlocdimg[iat] += nelement;
8793
this->nad[iat]++;

source/src_lcao/LOOP_elec.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ void LOOP_elec::set_matrix_grid(Record_adj &ra)
9292

9393
// need to first calculae lgd.
9494
// using GlobalC::GridT.init.
95-
GlobalC::GridT.cal_nnrg();
95+
GlobalC::GridT.cal_nnrg(pv);
9696
}
9797

9898
ModuleBase::timer::tick("LOOP_elec","set_matrix_grid");

source/src_lcao/LOOP_ions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ void LOOP_ions::final_scf(void)
563563

564564
// need to first calculae lgd.
565565
// using GlobalC::GridT.init.
566-
GlobalC::GridT.cal_nnrg();
566+
GlobalC::GridT.cal_nnrg(pv);
567567
}
568568
//------------------------------------------------------------------
569569

source/src_lcao/grid_index.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include "module_base/vector3.h"
2+
/// index structure for grid integral module
3+
/// in ABACUS, this index is stored for tracing:
4+
/// 1. starting row and column index (mu, nu)
5+
/// 2. R distance from atom 1 and atom2 (dR)
6+
/// 3. number of orbitals for atom1 and atom2 (nw1, nw2)
7+
namespace gridIntegral
8+
{
9+
10+
struct gridIndex
11+
{
12+
int nnrg;
13+
int mu;
14+
int nu;
15+
ModuleBase::Vector3<int> dR;
16+
int nw1;
17+
int nw2;
18+
};
19+
20+
}

0 commit comments

Comments
 (0)