11#include " elecstate_lcao.h"
22
3- #include " math_tools .h"
3+ #include " cal_dm .h"
44#include " module_base/timer.h"
55#include " module_gint/grid_technique.h"
66
77namespace elecstate
88{
99int ElecStateLCAO::out_wfc_lcao = 0 ;
1010
11- // for gamma_only(double case) and multi-k(complex<double> case)
12- template <typename T> void ElecStateLCAO::cal_dm (const ModuleBase::matrix& wg, const psi::Psi<T>& wfc, psi::Psi<T>& dm)
13- {
14- ModuleBase::TITLE (" ElecStateLCAO" , " cal_dm" );
15-
16- dm.resize (wfc.get_nk (), this ->loc ->ParaV ->ncol , this ->loc ->ParaV ->nrow );
17- const int nbands_local = wfc.get_nbands ();
18- const int nbasis_local = wfc.get_nbasis ();
19-
20- // dm = wfc.T * wg * wfc.conj()
21- // dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
22- for (int ik = 0 ; ik < wfc.get_nk (); ++ik)
23- {
24- wfc.fix_k (ik);
25- dm.fix_k (ik);
26- // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
27- psi::Psi<T> wg_wfc (wfc, 1 );
28-
29- int ib_global = 0 ;
30- for (int ib_local = 0 ; ib_local < nbands_local; ++ib_local)
31- {
32- while (ib_local != this ->loc ->ParaV ->trace_loc_col [ib_global])
33- {
34- ++ib_global;
35- if (ib_global >= wg.nc )
36- {
37- ModuleBase::WARNING_QUIT (" ElecStateLCAO::cal_dm" , " please check trace_loc_col!" );
38- }
39- }
40- const double wg_local = wg (ik, ib_global);
41- T* wg_wfc_pointer = &(wg_wfc (0 , ib_local, 0 ));
42- BlasConnector::scal (nbasis_local, wg_local, wg_wfc_pointer, 1 );
43- }
44-
45- // C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
46- #ifdef __MPI
47- psiMulPsiMpi (wg_wfc, wfc, dm, this ->loc ->ParaV ->desc_wfc , this ->loc ->ParaV ->desc );
48- #else
49- psiMulPsi (wg_wfc, wfc, dm);
50- #endif
51- }
52-
53- return ;
54- }
55-
5611// multi-k case
5712void ElecStateLCAO::psiToRho (const psi::Psi<std::complex <double >>& psi)
5813{
@@ -65,12 +20,12 @@ void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
6520 ModuleBase::GlobalFunc::NOTE (" Calculate the density matrix." );
6621
6722 // this part for calculating dm_k in 2d-block format, not used for charge now
68- psi::Psi<std::complex <double >> dm_k_2d (psi. get_nk (), psi. get_nbasis (), psi. get_nbasis () );
23+ // psi::Psi<std::complex<double>> dm_k_2d();
6924
7025 if (GlobalV::KS_SOLVER == " genelpa" || GlobalV::KS_SOLVER == " scalapack_gvx"
7126 || GlobalV::KS_SOLVER == " lapack" ) // Peize Lin test 2019-05-15
7227 {
73- this ->cal_dm ( this ->wg , psi, dm_k_2d );
28+ cal_dm ( this ->loc -> ParaV , this ->wg , psi, this -> loc -> dm_k );
7429 }
7530
7631 // this part for steps:
@@ -82,11 +37,29 @@ void ElecStateLCAO::psiToRho(const psi::Psi<std::complex<double>>& psi)
8237 for (int ik = 0 ; ik < psi.get_nk (); ik++)
8338 {
8439 psi.fix_k (ik);
85- this ->lowf ->wfc_2d_to_grid (ElecStateLCAO::out_wfc_lcao, psi.get_pointer (), this ->lowf ->wfc_k_grid [ik], ik);
40+ this ->lowf ->wfc_2d_to_grid (ElecStateLCAO::out_wfc_lcao, psi.get_pointer (), this ->lowf ->wfc_k_grid [ik], ik, this ->ekb , this ->wg );
41+ // added by zhengdy-soc, rearrange the wfc_k_grid from [up,down,up,down...] to [up,up...down,down...],
42+ if (GlobalV::NSPIN==4 )
43+ {
44+ int row = GlobalC::GridT.lgd ;
45+ std::vector<std::complex <double >> tmp (row);
46+ for (int ib=0 ; ib<GlobalV::NBANDS; ib++)
47+ {
48+ for (int iw=0 ; iw<row / GlobalV::NPOL; iw++)
49+ {
50+ tmp[iw] = this ->lowf ->wfc_k_grid [ik][ib][iw * GlobalV::NPOL];
51+ tmp[iw + row / GlobalV::NPOL] = this ->lowf ->wfc_k_grid [ik][ib][iw * GlobalV::NPOL + 1 ];
52+ }
53+ for (int iw=0 ; iw<row; iw++)
54+ {
55+ this ->lowf ->wfc_k_grid [ik][ib][iw] = tmp[iw];
56+ }
57+ }
58+ }
8659 }
8760 }
8861
89- this ->loc ->cal_dk_k (GlobalC::GridT);
62+ this ->loc ->cal_dk_k (GlobalC::GridT, this -> wg );
9063 for (int is = 0 ; is < GlobalV::NSPIN; is++)
9164 {
9265 ModuleBase::GlobalFunc::ZEROS (this ->charge ->rho [is], this ->charge ->nrxx ); // mohan 2009-11-10
@@ -119,9 +92,9 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
11992 {
12093 ModuleBase::timer::tick (" ElecStateLCAO" , " cal_dm_2d" );
12194
122- psi::Psi<double > dm_gamma_2d (psi. get_nk (), psi. get_nbasis (), psi. get_nbasis ()) ;
95+ // psi::Psi<double> dm_gamma_2d;
12396 // caution:wfc and dm
124- this ->cal_dm ( this ->wg , psi, dm_gamma_2d );
97+ cal_dm ( this ->loc -> ParaV , this ->wg , psi, this -> loc -> dm_gamma );
12598
12699 ModuleBase::timer::tick (" ElecStateLCAO" , " cal_dm_2d" );
127100
@@ -132,9 +105,10 @@ void ElecStateLCAO::psiToRho(const psi::Psi<double>& psi)
132105 {
133106 psi.fix_k (ik);
134107 double ** wfc_grid = nullptr ; // output but not do "2d-to-grid" conversion
135- this ->lowf ->wfc_2d_to_grid (ElecStateLCAO::out_wfc_lcao, psi.get_pointer (), wfc_grid);
108+ this ->lowf ->wfc_2d_to_grid (ElecStateLCAO::out_wfc_lcao, psi.get_pointer (), wfc_grid, this -> ekb , this -> wg );
136109 }
137- this ->loc ->dm2dToGrid (dm_gamma_2d, this ->loc ->DM [ik]); // transform dm_gamma[is].c to this->loc->DM[is]
110+ // this->loc->dm2dToGrid(this->loc->dm_gamma[ik], this->loc->DM[ik]); // transform dm_gamma[is].c to this->loc->DM[is]
111+ this ->loc ->cal_dk_gamma_from_2D_pub ();
138112 }
139113 }
140114
0 commit comments