Skip to content

Commit 4c0b28c

Browse files
authored
Fix & Perf: wrong sdft and sKG results (#1680)
* fix: make Makefile available
* update docs for makefile
* add more spaces for stress and force output
* do not change the length of screen output
* fix and performance: sdft and sKG
  1. fix wrong sdft results when nbands>0 && bndpar>0 && total kpoints > kpar
  2. fix wrong sKG results when scf_nmax == 0
  3. accelerate sKG
1 parent 2d36fda commit 4c0b28c

File tree

13 files changed

+2202
-2148
lines changed

13 files changed

+2202
-2148
lines changed

source/module_esolver/esolver_ks_pw.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef ESOLVER_KS_PW_H
22
#define ESOLVER_KS_PW_H
33
#include "./esolver_ks.h"
4+
#include "module_hamilt/ks_pw/velocity_pw.h"
45
// #include "Basis_PW.h"
56
// #include "Estate_PW.h"
67
// #include "Hamilton_PW.h"
@@ -27,6 +28,8 @@ namespace ModuleESolver
2728
//calculate conductivities with Kubo-Greenwood formula
2829
void KG(const int nche_KG, const FPTYPE fwhmin, const FPTYPE wcut,
2930
const FPTYPE dw_in, const int times, ModuleBase::matrix& wg);
31+
void jjcorr_ks(const int ik, const int nt, const double dt, ModuleBase::matrix& wg, hamilt::Velocity& velop,
32+
FPTYPE* ct11, FPTYPE* ct12, FPTYPE* ct22);
3033

3134
protected:
3235
virtual void beforescf(const int istep) override;

source/module_esolver/esolver_ks_pw_tool.cpp

Lines changed: 79 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "module_base/global_variable.h"
44
#include "src_pw/global.h"
55
#include "src_pw/occupy.h"
6-
#include "module_hamilt/ks_pw/velocity_pw.h"
76

87
namespace ModuleESolver
98
{
@@ -31,8 +30,6 @@ void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const FPTYPE fwhmin, c
3130
// KS conductivity
3231
//-----------------------------------------------------------
3332
cout << "Calculating conductivity..." << endl;
34-
char transn = 'N';
35-
char transc = 'C';
3633
int nw = ceil(wcut / dw_in);
3734
FPTYPE dw = dw_in / ModuleBase::Ry_to_eV; // converge unit in eV to Ry
3835
FPTYPE sigma = fwhmin / TWOSQRT2LN2 / ModuleBase::Ry_to_eV;
@@ -43,11 +40,6 @@ void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const FPTYPE fwhmin, c
4340
assert(nw >= 1);
4441
assert(nt >= 1);
4542
const int nk = GlobalC::kv.nks;
46-
const int ndim = 3;
47-
const int npwx = GlobalC::wf.npwx;
48-
const FPTYPE tpiba = GlobalC::ucell.tpiba;
49-
const int nbands = GlobalV::NBANDS;
50-
const FPTYPE ef = GlobalC::en.ef;
5143

5244
FPTYPE *ct11 = new FPTYPE[nt];
5345
FPTYPE *ct12 = new FPTYPE[nt];
@@ -60,78 +52,13 @@ void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const FPTYPE fwhmin, c
6052
for (int ik = 0; ik < nk; ++ik)
6153
{
6254
velop.init(ik);
63-
const int npw = GlobalC::kv.ngk[ik];
64-
complex<FPTYPE> *levc = &(this->psi[0](ik, 0, 0));
65-
complex<FPTYPE> *prevc = new complex<FPTYPE>[3 * npwx * nbands];
66-
// px|right>
67-
velop.act(this->psi, nbands*GlobalV::NPOL, levc, prevc);
68-
for (int id = 0; id < ndim; ++id)
69-
{
70-
this->p_hamilt->updateHk(ik);
71-
complex<FPTYPE> *pij = new complex<FPTYPE>[nbands * nbands];
72-
zgemm_(&transc,
73-
&transn,
74-
&nbands,
75-
&nbands,
76-
&npw,
77-
&ModuleBase::ONE,
78-
levc,
79-
&npwx,
80-
prevc + id * npwx * nbands,
81-
&npwx,
82-
&ModuleBase::ZERO,
83-
pij,
84-
&nbands);
85-
#ifdef __MPI
86-
MPI_Allreduce(MPI_IN_PLACE, pij, 2 * nbands * nbands, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
87-
#endif
88-
int ntper = nt / GlobalV::NPROC_IN_POOL;
89-
int itstart = ntper * GlobalV::RANK_IN_POOL;
90-
if (nt % GlobalV::NPROC_IN_POOL > GlobalV::RANK_IN_POOL)
91-
{
92-
ntper++;
93-
itstart += GlobalV::RANK_IN_POOL;
94-
}
95-
else
96-
{
97-
itstart += nt % GlobalV::NPROC_IN_POOL;
98-
}
99-
100-
for (int it = itstart; it < itstart + ntper; ++it)
101-
// for(int it = 0 ; it < nt; ++it)
102-
{
103-
FPTYPE tmct11 = 0;
104-
FPTYPE tmct12 = 0;
105-
FPTYPE tmct22 = 0;
106-
FPTYPE *enb = &(this->pelec->ekb(ik, 0));
107-
for (int ib = 0; ib < nbands; ++ib)
108-
{
109-
FPTYPE ei = enb[ib];
110-
FPTYPE fi = wg(ik, ib);
111-
for (int jb = ib + 1; jb < nbands; ++jb)
112-
{
113-
FPTYPE ej = enb[jb];
114-
FPTYPE fj = wg(ik, jb);
115-
FPTYPE tmct = sin((ej - ei) * (it)*dt) * (fi - fj) * norm(pij[ib * nbands + jb]);
116-
tmct11 += tmct;
117-
tmct12 += -tmct * ((ei + ej) / 2 - ef);
118-
tmct22 += tmct * pow((ei + ej) / 2 - ef, 2);
119-
}
120-
}
121-
ct11[it] += tmct11 / 2.0;
122-
ct12[it] += tmct12 / 2.0;
123-
ct22[it] += tmct22 / 2.0;
124-
}
125-
delete[] pij;
126-
}
127-
delete[] prevc;
55+
jjcorr_ks(ik, nt, dt, wg, velop, ct11,ct12,ct22);
12856
}
12957
#ifdef __MPI
13058
MPI_Allreduce(MPI_IN_PLACE, ct11, nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
13159
MPI_Allreduce(MPI_IN_PLACE, ct12, nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
13260
MPI_Allreduce(MPI_IN_PLACE, ct22, nt, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
133-
#endif
134-
61+
#endif
13562
//------------------------------------------------------------------
13663
// Output
13764
//------------------------------------------------------------------
@@ -144,6 +71,83 @@ void ESolver_KS_PW<FPTYPE, Device>::KG(const int nche_KG, const FPTYPE fwhmin, c
14471
delete[] ct22;
14572
}
14673

74+
// Accumulates, for the single k-point index `ik`, this rank's contribution to the
// three time-indexed arrays ct11 / ct12 / ct22.
//
// Parameters:
//   ik     index used to select the wavefunction block (psi[0](ik,0,0)), the
//          plane-wave count (kv.ngk[ik]), the band energies (ekb(ik,.)) and the
//          weights (wg(ik,.)).
//   nt     number of time indices; the range [0, nt) is split across the ranks
//          of the pool (see the ntper / itstart computation below).
//   dt     time step; the phase used is (ej - ei) * it * dt.
//   wg     per-(k, band) weights; only differences fi - fj enter the sums.
//   velop  operator object whose act() fills `prevc` from `levc`
//          (presumably the velocity operator, going by its type name - confirm).
//   ct11, ct12, ct22
//          output accumulators, indexed by `it`; entries are updated with `+=`,
//          so the caller is expected to have initialised them. Only the entries
//          in this rank's [itstart, itstart + ntper) slice are touched here.
//
// NOTE(review): every rank of the pool must call this function together when
// __MPI is defined, because it contains an MPI_Allreduce on POOL_WORLD.
template <typename FPTYPE, typename Device>
75+
void ESolver_KS_PW<FPTYPE, Device>:: jjcorr_ks(const int ik, const int nt, const double dt, ModuleBase::matrix& wg, hamilt::Velocity &velop,
76+
FPTYPE* ct11, FPTYPE* ct12, FPTYPE* ct22)
77+
{
78+
// Transpose flags handed to zgemm_: 'N' = use as is, 'C' = conjugate transpose.
char transn = 'N';
79+
char transc = 'C';
80+
// Number of slices stored in `prevc` (buffer is sized 3 * npwx * nbands below).
const int ndim = 3;
81+
// Leading dimension / per-band stride used for both levc and prevc.
const int npwx = GlobalC::wf.npwx;
82+
// NOTE(review): `tpiba` is not referenced anywhere else in this function as shown.
const FPTYPE tpiba = GlobalC::ucell.tpiba;
83+
const int nbands = GlobalV::NBANDS;
84+
// Reference energy subtracted from the pair-average energy in ct12 / ct22.
const FPTYPE ef = this->pelec->ef;
85+
// Number of rows actually contracted in zgemm_ for this k-point (<= npwx presumably).
const int npw = GlobalC::kv.ngk[ik];
86+
// Start of the wavefunction data for k-point ik.
std::complex<FPTYPE> *levc = &(this->psi[0](ik, 0, 0));
87+
// Scratch buffer for the operator result; released at the end of the function.
// NOTE(review): raw new[]/delete[] - nothing frees this if act() or zgemm_ exits early.
complex<FPTYPE> *prevc = new complex<FPTYPE>[3 * npwx * nbands];
88+
// px|right>
// i.e. apply the operator to the wavefunctions once; the result slices are read
// below at offsets id * npwx * nbands.
89+
velop.act(this->psi, nbands*GlobalV::NPOL, levc, prevc);
90+
for (int id = 0; id < ndim; ++id)
91+
{
92+
// nbands x nbands matrix-element buffer for slice `id`.
// NOTE(review): allocated and freed on every iteration although its size
// does not depend on `id`.
complex<FPTYPE> *pij = new complex<FPTYPE>[nbands * nbands];
93+
// Assuming the standard BLAS zgemm contract: pij = levc^H * prevc[id], an
// nbands x nbands product contracted over the first npw rows, with npwx as
// the leading dimension of both inputs.
// NOTE(review): zgemm_ is the double-complex routine while the buffers are
// complex<FPTYPE> - this looks correct only when FPTYPE is double; confirm.
zgemm_(&transc,
94+
&transn,
95+
&nbands,
96+
&nbands,
97+
&npw,
98+
&ModuleBase::ONE,
99+
levc,
100+
&npwx,
101+
prevc + id * npwx * nbands,
102+
&npwx,
103+
&ModuleBase::ZERO,
104+
pij,
105+
&nbands);
106+
#ifdef __MPI
107+
// Sum the partial products over the ranks of the pool; the complex array is
// reduced as 2 * nbands * nbands reals.
// NOTE(review): MPI_DOUBLE is hard-coded, which again assumes FPTYPE == double.
MPI_Allreduce(MPI_IN_PLACE, pij, 2 * nbands * nbands, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
108+
#endif
109+
// Split the nt time indices across the NPROC_IN_POOL ranks: every rank gets
// nt / NPROC_IN_POOL indices and the first (nt % NPROC_IN_POOL) ranks get one
// extra; itstart is shifted so the slices are contiguous and non-overlapping.
// NOTE(review): this split does not depend on `id` but is recomputed per iteration.
int ntper = nt / GlobalV::NPROC_IN_POOL;
110+
int itstart = ntper * GlobalV::RANK_IN_POOL;
111+
if (nt % GlobalV::NPROC_IN_POOL > GlobalV::RANK_IN_POOL)
112+
{
113+
ntper++;
114+
itstart += GlobalV::RANK_IN_POOL;
115+
}
116+
else
117+
{
118+
itstart += nt % GlobalV::NPROC_IN_POOL;
119+
}
120+
121+
// Only this rank's slice of time indices is processed here.
for (int it = itstart; it < itstart + ntper; ++it)
122+
{
123+
// Per-time-index partial sums over band pairs.
FPTYPE tmct11 = 0;
124+
FPTYPE tmct12 = 0;
125+
FPTYPE tmct22 = 0;
126+
// Band energies for k-point ik.
FPTYPE *enb = &(this->pelec->ekb(ik, 0));
127+
for (int ib = 0; ib < nbands; ++ib)
128+
{
129+
FPTYPE ei = enb[ib];
130+
FPTYPE fi = wg(ik, ib);
131+
// Only pairs with jb > ib are visited.
for (int jb = ib + 1; jb < nbands; ++jb)
132+
{
133+
FPTYPE ej = enb[jb];
134+
FPTYPE fj = wg(ik, jb);
135+
// Pair term: sin((ej - ei) * it * dt) * (fi - fj) * norm(pij entry).
// norm() is presumably std::norm, i.e. the squared magnitude - confirm.
FPTYPE tmct = sin((ej - ei) * (it)*dt) * (fi - fj) * norm(pij[ib * nbands + jb]);
136+
// The same pair term is accumulated with weights 1, -(e - ef) and
// (e - ef)^2, where e = (ei + ej) / 2.
tmct11 += tmct;
137+
tmct12 += -tmct * ((ei + ej) / 2 - ef);
138+
// NOTE(review): pow(x, 2) could be written as x * x.
tmct22 += tmct * pow((ei + ej) / 2 - ef, 2);
139+
}
140+
}
141+
// Add half of each pair sum to the outputs; contributions from all `id`
// slices accumulate into the same ct*[it] entry.
ct11[it] += tmct11 / 2.0;
142+
ct12[it] += tmct12 / 2.0;
143+
ct22[it] += tmct22 / 2.0;
144+
}
145+
delete[] pij;
146+
}
147+
delete[] prevc;
148+
return;
149+
}
150+
147151
template <typename FPTYPE, typename Device>
148152
void ESolver_KS_PW<FPTYPE, Device>::calcondw(const int nt,
149153
const FPTYPE dt,

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ void ESolver_SDFT_PW::postprocess()
178178
hsolver::DiagoIterAssist<double>::need_subspace = false;
179179
this->phsol->solve(this->p_hamilt, this->psi[0], this->pelec,this->stowf,istep, iter, GlobalV::KS_SOLVER, true);
180180
((hsolver::HSolverPW_SDFT*)phsol)->stoiter.cleanchiallorder();//release lots of memories
181+
GlobalC::en.ef = this->pelec->ef; //Temporary: Please use this->pelec->ef. GlobalC::en.ef is not recommended.
181182
}
182183
int nche_test = 0;
183184
if(INPUT.cal_cond) nche_test = std::max(nche_test, INPUT.cond_nche);

0 commit comments

Comments
 (0)