11#include " unk_overlap_pw.h"
22
3- #include " module_parameter/parameter.h"
43#include " module_hamilt_pw/hamilt_pwdft/global.h"
4+ #include " module_parameter/parameter.h"
55
66unkOverlap_pw::unkOverlap_pw ()
77{
8- // GlobalV::ofs_running << "this is unkOverlap_pw()" << std::endl;
8+ // GlobalV::ofs_running << "this is unkOverlap_pw()" << std::endl;
99}
1010
1111unkOverlap_pw::~unkOverlap_pw ()
1212{
13- // GlobalV::ofs_running << "this is ~unkOverlap_pw()" << std::endl;
13+ // GlobalV::ofs_running << "this is ~unkOverlap_pw()" << std::endl;
1414}
1515
1616std::complex <double > unkOverlap_pw::unkdotp_G (const ModulePW::PW_Basis_K* wfcpw,
@@ -20,50 +20,44 @@ std::complex<double> unkOverlap_pw::unkdotp_G(const ModulePW::PW_Basis_K* wfcpw,
2020 const int iband_R,
2121 const psi::Psi<std::complex <double >>* evc)
2222{
23-
24- std::complex <double > result (0.0 ,0.0 );
25- const int number_pw = wfcpw->npw ;
26- std::complex <double > *unk_L = new std::complex <double >[number_pw];
27- std::complex <double > *unk_R = new std::complex <double >[number_pw];
28- ModuleBase::GlobalFunc::ZEROS (unk_L,number_pw);
29- ModuleBase::GlobalFunc::ZEROS (unk_R,number_pw);
30-
31-
32- for (int igl = 0 ; igl < evc->get_ngk (ik_L); igl++)
33- {
34- unk_L[wfcpw->getigl2ig (ik_L,igl)] = evc[0 ](ik_L, iband_L, igl);
35- }
36-
37- for (int igl = 0 ; igl < evc->get_ngk (ik_R); igl++)
38- {
39- unk_R[wfcpw->getigl2ig (ik_R,igl)] = evc[0 ](ik_R, iband_R, igl);
40- }
41-
42-
43- for (int iG = 0 ; iG < number_pw; iG++)
44- {
45-
46- result = result + conj (unk_L[iG]) * unk_R[iG];
47-
48- }
4923
24+ std::complex <double > result (0.0 , 0.0 );
25+ const int number_pw = wfcpw->npw ;
26+ std::complex <double >* unk_L = new std::complex <double >[number_pw];
27+ std::complex <double >* unk_R = new std::complex <double >[number_pw];
28+ ModuleBase::GlobalFunc::ZEROS (unk_L, number_pw);
29+ ModuleBase::GlobalFunc::ZEROS (unk_R, number_pw);
30+
31+ for (int igl = 0 ; igl < evc->get_ngk (ik_L); igl++)
32+ {
33+ unk_L[wfcpw->getigl2ig (ik_L, igl)] = evc[0 ](ik_L, iband_L, igl);
34+ }
35+
36+ for (int igl = 0 ; igl < evc->get_ngk (ik_R); igl++)
37+ {
38+ unk_R[wfcpw->getigl2ig (ik_R, igl)] = evc[0 ](ik_R, iband_R, igl);
39+ }
40+
41+ for (int iG = 0 ; iG < number_pw; iG++)
42+ {
43+
44+ result = result + conj (unk_L[iG]) * unk_R[iG];
45+ }
5046
5147#ifdef __MPI
5248 // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1.
53- double in_date_real = result.real ();
54- double in_date_imag = result.imag ();
55- double out_date_real = 0.0 ;
56- double out_date_imag = 0.0 ;
57- MPI_Allreduce (&in_date_real , &out_date_real , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
58- MPI_Allreduce (&in_date_imag , &out_date_imag , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
59- result = std::complex <double >(out_date_real,out_date_imag);
49+ double in_date_real = result.real ();
50+ double in_date_imag = result.imag ();
51+ double out_date_real = 0.0 ;
52+ double out_date_imag = 0.0 ;
53+ MPI_Allreduce (&in_date_real, &out_date_real, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
54+ MPI_Allreduce (&in_date_imag, &out_date_imag, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
55+ result = std::complex <double >(out_date_real, out_date_imag);
6056#endif
6157
62- delete[] unk_L;
63- delete[] unk_R;
64- return result;
65-
66-
58+ delete[] unk_L;
59+ delete[] unk_R;
60+ return result;
6761}
6862
6963std::complex <double > unkOverlap_pw::unkdotp_G0 (const ModulePW::PW_Basis* rhopw,
@@ -75,24 +69,24 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
7569 const psi::Psi<std::complex <double >>* evc,
7670 const ModuleBase::Vector3<double > G)
7771{
78- // (1) set value
79- std::complex <double > result (0.0 ,0.0 );
72+ // (1) set value
73+ std::complex <double > result (0.0 , 0.0 );
8074 std::complex <double >* psi_r = new std::complex <double >[wfcpw->nmaxgr ];
8175 std::complex <double >* phase = new std::complex <double >[rhopw->nmaxgr ];
8276
8377 // get the phase value in realspace
8478 for (int ig = 0 ; ig < rhopw->nmaxgr ; ig++)
8579 {
86- ModuleBase::Vector3<double > delta_G = rhopw->gdirect [ig] - G;
87- if (delta_G.norm2 () < 1e-10 ) // rhopw->gdirect[ig] == G
88- {
89- phase[ig] = std::complex <double >(1.0 ,0.0 );
90- break ;
91- }
92- }
93-
94- // (2) fft and get value
95- rhopw->recip2real (phase, phase);
80+ ModuleBase::Vector3<double > delta_G = rhopw->gdirect [ig] - G;
81+ if (delta_G.norm2 () < 1e-10 ) // rhopw->gdirect[ig] == G
82+ {
83+ phase[ig] = std::complex <double >(1.0 , 0.0 );
84+ break ;
85+ }
86+ }
87+
88+ // (2) fft and get value
89+ rhopw->recip2real (phase, phase);
9690 wfcpw->recip2real (&evc[0 ](ik_L, iband_L, 0 ), psi_r, ik_L);
9791
9892 for (int ir = 0 ; ir < rhopw->nmaxgr ; ir++)
@@ -110,17 +104,17 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
110104
111105#ifdef __MPI
112106 // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1.
113- double in_date_real = result.real ();
114- double in_date_imag = result.imag ();
115- double out_date_real = 0.0 ;
116- double out_date_imag = 0.0 ;
117- MPI_Allreduce (&in_date_real , &out_date_real , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
118- MPI_Allreduce (&in_date_imag , &out_date_imag , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
119- result = std::complex <double >(out_date_real,out_date_imag);
107+ double in_date_real = result.real ();
108+ double in_date_imag = result.imag ();
109+ double out_date_real = 0.0 ;
110+ double out_date_imag = 0.0 ;
111+ MPI_Allreduce (&in_date_real, &out_date_real, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
112+ MPI_Allreduce (&in_date_imag, &out_date_imag, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
113+ result = std::complex <double >(out_date_real, out_date_imag);
120114#endif
121-
122- delete[] psi_r;
123- delete[] phase;
115+
116+ delete[] psi_r;
117+ delete[] phase;
124118 return result;
125119}
126120
@@ -133,18 +127,18 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf
133127 const int npwx,
134128 const psi::Psi<std::complex <double >>* evc)
135129{
136-
137- std::complex <double > result (0.0 ,0.0 );
130+
131+ std::complex <double > result (0.0 , 0.0 );
138132 const int number_pw = wfcpw->npw ;
139133 std::complex <double >* unk_L = new std::complex <double >[number_pw * PARAM.globalv .npol ];
140134 std::complex <double >* unk_R = new std::complex <double >[number_pw * PARAM.globalv .npol ];
141- ModuleBase::GlobalFunc::ZEROS (unk_L,number_pw* PARAM.globalv .npol );
142- ModuleBase::GlobalFunc::ZEROS (unk_R,number_pw* PARAM.globalv .npol );
143-
144- for (int i = 0 ; i < PARAM.globalv .npol ; i++)
145- {
146- for (int igl = 0 ; igl < evc->get_ngk (ik_L); igl++)
147- {
135+ ModuleBase::GlobalFunc::ZEROS (unk_L, number_pw * PARAM.globalv .npol );
136+ ModuleBase::GlobalFunc::ZEROS (unk_R, number_pw * PARAM.globalv .npol );
137+
138+ for (int i = 0 ; i < PARAM.globalv .npol ; i++)
139+ {
140+ for (int igl = 0 ; igl < evc->get_ngk (ik_L); igl++)
141+ {
148142 unk_L[wfcpw->getigl2ig (ik_L, igl) + i * number_pw] = evc[0 ](ik_L, iband_L, igl + i * npwx);
149143 }
150144
@@ -154,32 +148,29 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G(const ModulePW::PW_Basis_K* wf
154148 }
155149 }
156150
157- for (int iG = 0 ; iG < number_pw* PARAM.globalv .npol ; iG++)
158- {
151+ for (int iG = 0 ; iG < number_pw * PARAM.globalv .npol ; iG++)
152+ {
159153
160- result = result + conj (unk_L[iG]) * unk_R[iG];
154+ result = result + conj (unk_L[iG]) * unk_R[iG];
155+ }
161156
162- }
163-
164157#ifdef __MPI
165158 // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1.
166- double in_date_real = result.real ();
167- double in_date_imag = result.imag ();
168- double out_date_real = 0.0 ;
169- double out_date_imag = 0.0 ;
170- MPI_Allreduce (&in_date_real , &out_date_real , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
171- MPI_Allreduce (&in_date_imag , &out_date_imag , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
172- result = std::complex <double >(out_date_real,out_date_imag);
159+ double in_date_real = result.real ();
160+ double in_date_imag = result.imag ();
161+ double out_date_real = 0.0 ;
162+ double out_date_imag = 0.0 ;
163+ MPI_Allreduce (&in_date_real, &out_date_real, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
164+ MPI_Allreduce (&in_date_imag, &out_date_imag, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
165+ result = std::complex <double >(out_date_real, out_date_imag);
173166#endif
174167
175- delete[] unk_L;
176- delete[] unk_R;
177- return result;
178-
179-
168+ delete[] unk_L;
169+ delete[] unk_R;
170+ return result;
180171}
181172
182- // here G is in direct coordinate
173+ // here G is in direct coordinate
183174std::complex <double > unkOverlap_pw::unkdotp_soc_G0 (const ModulePW::PW_Basis* rhopw,
184175 const ModulePW::PW_Basis_K* wfcpw,
185176 const int ik_L,
@@ -189,32 +180,32 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
189180 const psi::Psi<std::complex <double >>* evc,
190181 const ModuleBase::Vector3<double > G)
191182{
192- // (1) set value
193- std::complex <double > result (0.0 ,0.0 );
194- std::complex <double > * phase =new std::complex <double >[rhopw->nmaxgr ];
183+ // (1) set value
184+ std::complex <double > result (0.0 , 0.0 );
185+ std::complex <double >* phase = new std::complex <double >[rhopw->nmaxgr ];
195186 std::complex <double >* psi_up = new std::complex <double >[wfcpw->nmaxgr ];
196187 std::complex <double >* psi_down = new std::complex <double >[wfcpw->nmaxgr ];
197188 const int npwx = wfcpw->npwk_max ;
198189
199190 // get the phase value in realspace
200191 for (int ig = 0 ; ig < rhopw->npw ; ig++)
201- {
202- if (rhopw->gdirect [ig] == G)
203- {
204- phase[ig] = std::complex <double >(1.0 ,0.0 );
205- break ;
206- }
207- }
208-
209- // (2) fft and get value
210- rhopw->recip2real (phase, phase);
192+ {
193+ if (rhopw->gdirect [ig] == G)
194+ {
195+ phase[ig] = std::complex <double >(1.0 , 0.0 );
196+ break ;
197+ }
198+ }
199+
200+ // (2) fft and get value
201+ rhopw->recip2real (phase, phase);
211202 wfcpw->recip2real (&evc[0 ](ik_L, iband_L, 0 ), psi_up, ik_L);
212203 wfcpw->recip2real (&evc[0 ](ik_L, iband_L, npwx), psi_down, ik_L);
213204
214205 for (int ir = 0 ; ir < wfcpw->nrxx ; ir++)
215206 {
216207 psi_up[ir] = psi_up[ir] * phase[ir];
217- psi_down[ir] = psi_down[ir] * phase[ir];
208+ psi_down[ir] = psi_down[ir] * phase[ir];
218209 }
219210
220211 // (3) calculate the overlap in ik_L and ik_R
@@ -223,27 +214,31 @@ std::complex<double> unkOverlap_pw::unkdotp_soc_G0(const ModulePW::PW_Basis* rho
223214
224215 for (int i = 0 ; i < PARAM.globalv .npol ; i++)
225216 {
226- for (int ig = 0 ; ig < evc->get_ngk (ik_R); ig++)
227- {
228- if ( i == 0 ) { result = result + conj ( psi_up[ig] ) * evc[0 ](ik_R, iband_R, ig);
229- }
230- if ( i == 1 ) { result = result + conj ( psi_down[ig] ) * evc[0 ](ik_R, iband_R, ig + npwx);
231- }
232- }
233- }
234-
217+ for (int ig = 0 ; ig < evc->get_ngk (ik_R); ig++)
218+ {
219+ if (i == 0 )
220+ {
221+ result = result + conj (psi_up[ig]) * evc[0 ](ik_R, iband_R, ig);
222+ }
223+ if (i == 1 )
224+ {
225+ result = result + conj (psi_down[ig]) * evc[0 ](ik_R, iband_R, ig + npwx);
226+ }
227+ }
228+ }
229+
235230#ifdef __MPI
236231 // note: the mpi uses MPI_COMMON_WORLD,so you must make the GlobalV::KPAR = 1.
237- double in_date_real = result.real ();
238- double in_date_imag = result.imag ();
239- double out_date_real = 0.0 ;
240- double out_date_imag = 0.0 ;
241- MPI_Allreduce (&in_date_real , &out_date_real , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
242- MPI_Allreduce (&in_date_imag , &out_date_imag , 1 , MPI_DOUBLE , MPI_SUM , POOL_WORLD);
243- result = std::complex <double >(out_date_real,out_date_imag);
232+ double in_date_real = result.real ();
233+ double in_date_imag = result.imag ();
234+ double out_date_real = 0.0 ;
235+ double out_date_imag = 0.0 ;
236+ MPI_Allreduce (&in_date_real, &out_date_real, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
237+ MPI_Allreduce (&in_date_imag, &out_date_imag, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
238+ result = std::complex <double >(out_date_real, out_date_imag);
244239#endif
245-
246- delete[] psi_up;
247- delete[] psi_down;
240+
241+ delete[] psi_up;
242+ delete[] psi_down;
248243 return result;
249244}
0 commit comments