Skip to content

Commit 3bf4613

Browse files
committed
optimize Stochastic DOS of method_sto = 2
1 parent 6b44d70 commit 3bf4613

File tree

7 files changed

+95
-28
lines changed

7 files changed

+95
-28
lines changed

source/module_esolver/esolver_sdft_pw_tool.cpp

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,19 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
465465
Stochastic_hchi& stohchi = stoiter.stohchi;
466466
const int npwx = GlobalC::wf.npwx;
467467

468-
double * spolyv = new double [nche_dos];
469-
ModuleBase::GlobalFunc::ZEROS(spolyv, nche_dos);
468+
double * spolyv = nullptr;
469+
std::complex<double> *allorderchi = nullptr;
470+
if(stoiter.method == 1)
471+
{
472+
spolyv = new double [nche_dos];
473+
ModuleBase::GlobalFunc::ZEROS(spolyv, nche_dos);
474+
}
475+
else
476+
{
477+
spolyv = new double [nche_dos*nche_dos];
478+
ModuleBase::GlobalFunc::ZEROS(spolyv, nche_dos*nche_dos);
479+
allorderchi = new std::complex<double> [this->stowf.nchip_max * npwx * nche_dos];
480+
}
470481
cout<<"1. TracepolyA:"<<endl;
471482
for (int ik = 0;ik < nk;ik++)
472483
{
@@ -477,19 +488,37 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
477488
}
478489
stohchi.current_ik = ik;
479490
const int npw = GlobalC::kv.ngk[ik];
480-
const int nchip = this->stowf.nchip[ik];
491+
const int nchipk = this->stowf.nchip[ik];
481492

482493
complex<double> * pchi;
483494
if(GlobalV::NBANDS > 0)
484495
pchi = stowf.chiortho[ik].c;
485496
else
486497
pchi = stowf.chi0[ik].c;
487-
che.tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchip);
488-
for(int i = 0 ; i < nche_dos ; ++i)
498+
if(stoiter.method == 1)
499+
{
500+
che.tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchipk);
501+
for(int i = 0 ; i < nche_dos ; ++i)
502+
{
503+
spolyv[i] += che.polytrace[i] * GlobalC::kv.wk[ik] / 2 ;
504+
}
505+
}
506+
else
489507
{
490-
spolyv[i] += che.polytrace[i] * GlobalC::kv.wk[ik] / 2 ;
508+
ModuleBase::GlobalFunc::ZEROS(allorderchi, this->stowf.nchip_max * npwx * nche_dos);
509+
che.calpolyvec_complex(&stohchi, &Stochastic_hchi::hchi_norm, pchi, allorderchi, npw, npwx, nchipk);
510+
double* vec_all= (double *) allorderchi;
511+
char trans = 'T';
512+
char normal = 'N';
513+
double one = 1;
514+
int LDA = npwx * nchipk * 2;
515+
int M = npwx * nchipk * 2;
516+
int N = nche_dos;
517+
double kweight = GlobalC::kv.wk[ik] / 2;
518+
dgemm_(&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
491519
}
492520
}
521+
if(stoiter.method == 2) delete[] allorderchi;
493522

494523
string dosfile = GlobalV::global_out_dir+"DOS1_smearing.dat";
495524
ofstream ofsdos(dosfile.c_str());
@@ -498,17 +527,26 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
498527
ModuleBase::GlobalFunc::ZEROS(dos,ndos);
499528
stoiter.stofunc.sigma = sigmain / ModuleBase::Ry_to_eV;
500529
double sum = 0;
501-
double error = 0;
530+
double maxerror = 0;
502531
ofsdos<<setw(8)<<"## E(eV) "<<setw(20)<<"dos(eV^-1)"<<setw(20)<<"sum"<<setw(20)<<"Error(eV^-1)"<<endl;
503532
cout<<"2. Dos:"<<endl;
504533
int n10 = ndos/10;
505534
int percent = 10;
506535
for(int ie = 0; ie < ndos; ++ie)
507536
{
508-
stoiter.stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
509-
che.calcoef_real(&stoiter.stofunc, &Sto_Func<double>::ngauss);
510537
double KS_dos = 0;
511-
double sto_dos = BlasConnector::dot(nche_dos,che.coef_real,1,spolyv,1);
538+
double sto_dos = 0;
539+
stoiter.stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
540+
if(stoiter.method == 1)
541+
{
542+
che.calcoef_real(&stoiter.stofunc, &Sto_Func<double>::ngauss);
543+
sto_dos = BlasConnector::dot(nche_dos,che.coef_real,1,spolyv,1);
544+
}
545+
else
546+
{
547+
che.calcoef_real(&stoiter.stofunc, &Sto_Func<double>::nroot_gauss);
548+
sto_dos = stoiter.vTMv(che.coef_real,spolyv,nche_dos);
549+
}
512550
if(GlobalV::NBANDS > 0)
513551
{
514552
for(int ik = 0; ik < nk; ++ik)
@@ -525,11 +563,23 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
525563
MPI_Allreduce(MPI_IN_PLACE, &KS_dos, 1, MPI_DOUBLE, MPI_SUM , STO_WORLD);
526564
MPI_Allreduce(MPI_IN_PLACE, &sto_dos, 1, MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD);
527565
#endif
528-
double tmpre = che.coef_real[nche_dos-1] * spolyv[nche_dos-1];
566+
double tmpre = 0;
567+
if(stoiter.method == 1)
568+
{
569+
tmpre = che.coef_real[nche_dos-1] * spolyv[nche_dos-1];
570+
}
571+
else
572+
{
573+
const int norder = nche_dos;
574+
double last_coef = che.coef_real[norder-1];
575+
double last_spolyv = spolyv[norder*norder - 1];
576+
tmpre = last_coef *(BlasConnector::dot(norder,che.coef_real,1,spolyv+norder*(norder-1),1)
577+
+ BlasConnector::dot(norder,che.coef_real,1,spolyv+norder-1,norder)-last_coef*last_spolyv);
578+
}
529579
#ifdef __MPI
530580
MPI_Allreduce(MPI_IN_PLACE, &tmpre, 1, MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD);
531581
#endif
532-
if(error < tmpre) error = tmpre;
582+
if(maxerror < tmpre) maxerror = tmpre;
533583
dos[ie] = (KS_dos + sto_dos) / ModuleBase::Ry_to_eV;
534584
sum += dos[ie];
535585
ofsdos <<setw(8)<< emin + ie * de <<setw(20)<<dos[ie]<<setw(20)<<sum * de <<setw(20) <<tmpre <<endl;
@@ -541,7 +591,7 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
541591
}
542592
cout<<endl;
543593
cout<<"Finish DOS"<<endl;
544-
cout<<scientific<<"DOS max Chebyshev Error: "<<error<<endl;
594+
cout<<scientific<<"DOS max Chebyshev Error: "<<maxerror<<endl;
545595
delete[] dos;
546596
delete[] spolyv;
547597
return;

source/src_pw/sto_func.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ REAL Sto_Func<REAL>:: nxfd(REAL rawe)
5959
REAL DeltaE = (Emax - Emin)/2;
6060
REAL e = rawe * DeltaE + Ebar;
6161
REAL ne_mu = (e - mu) / this->tem ;
62-
if(ne_mu > 40)
62+
if(ne_mu > 36)
6363
return 0;
6464
else
6565
return e / (1 + exp(ne_mu));
@@ -103,9 +103,9 @@ REAL Sto_Func<REAL>:: n_root_fdlnfd(REAL rawe)
103103
REAL Ebar = (Emin + Emax)/2;
104104
REAL DeltaE = (Emax - Emin)/2;
105105
REAL ne_mu = (rawe * DeltaE + Ebar - mu) / this->tem ;
106-
if(ne_mu > 36)
106+
if(ne_mu > 72)
107107
return 0;
108-
else if(ne_mu < -36)
108+
else if(ne_mu < -72)
109109
return 0;
110110
else
111111
{
@@ -170,12 +170,25 @@ REAL Sto_Func<REAL>::ngauss(REAL rawe)
170170
REAL DeltaE = (Emax - Emin)/2;
171171
REAL e = rawe * DeltaE + Ebar;
172172
REAL a = pow((targ_e-e),2)/2.0/pow(sigma,2);
173-
if(a > 72)
173+
if(a > 32)
174174
return 0;
175175
else
176176
return exp(-a) /sqrt(TWOPI) / sigma ;
177177
}
178178

179+
// Square root of the normalized Gaussian that ngauss() evaluates:
//   nroot_gauss(rawe)^2 == exp(-(targ_e-e)^2 / (2*sigma^2)) / (sqrt(TWOPI)*sigma)
// rawe is mapped linearly onto an energy e using the midpoint (Ebar) and the
// half-width (DeltaE) of [Emin, Emax].
// NOTE(review): rawe is presumably the Chebyshev variable in [-1, 1] - confirm against callers.
// Returns 0 once the exponent a exceeds 32, i.e. the far tail is treated as zero.
template<typename REAL>
180+
REAL Sto_Func<REAL>::nroot_gauss(REAL rawe)
181+
{
182+
REAL Ebar = (Emin + Emax)/2;
183+
REAL DeltaE = (Emax - Emin)/2;
184+
REAL e = rawe * DeltaE + Ebar;
185+
// Exponent is half of the one used in ngauss (divisor 4.0 instead of 2.0) because this is the square root.
REAL a = pow((targ_e-e),2)/4.0/pow(sigma,2);
186+
if(a > 32)
187+
return 0;
188+
else
189+
// Normalization is the square root of ngauss's 1/(sqrt(TWOPI)*sigma) prefactor.
return exp(-a) /sqrt(sqrt(TWOPI) * sigma) ;
190+
}
191+
179192
//we only have two examples: double and float.
180193
template class Sto_Func<double>;
181194
#ifdef __MIX_PRECISION

source/src_pw/sto_func.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class Sto_Func
3333
REAL targ_e;
3434
REAL gauss(REAL e);
3535
REAL ngauss(REAL e);
36+
REAL nroot_gauss(REAL e);
3637

3738
};
3839

source/src_pw/sto_iter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "global.h"
1010
#include "occupy.h"
1111

12-
double vTMv(const double *v, const double * M, const int n)
12+
double Stochastic_Iter::vTMv(const double *v, const double * M, const int n)
1313
{
1414
const char normal = 'N';
1515
const double one = 1;
@@ -193,7 +193,7 @@ void Stochastic_Iter::check_precision(const double ref, const double thr, const
193193
const int norder = p_che->norder;
194194
double last_coef = p_che->coef_real[norder-1];
195195
double last_spolyv = spolyv[norder*norder - 1];
196-
error += last_coef *(BlasConnector::dot(norder,p_che->coef_real,1,spolyv+norder*(norder-1),1)
196+
error = last_coef *(BlasConnector::dot(norder,p_che->coef_real,1,spolyv+norder*(norder-1),1)
197197
+ BlasConnector::dot(norder,p_che->coef_real,1,spolyv+norder-1,norder)-last_coef*last_spolyv);
198198
}
199199

source/src_pw/sto_iter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class Stochastic_Iter
6666
void calPn(const int& ik, Stochastic_WF& stowf);
6767
//cal Tnchi = \sum_n C_n*T_n(\hat{h})|\chi>
6868
void calTnchi_ik(const int& ik, Stochastic_WF& stowf);
69+
//cal v^T*M*v
70+
double vTMv(const double *v, const double * M, const int n);
6971

7072
};
7173

source/src_pw/sto_wf.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,10 @@ Stochastic_WF::Stochastic_WF()
1818

1919
Stochastic_WF::~Stochastic_WF()
2020
{
21-
if (chi0 != nullptr)
22-
delete[] chi0;
23-
if (shchi != nullptr)
24-
delete[] shchi;
25-
if (chiortho != nullptr)
26-
delete[] chiortho;
27-
if (nchip != nullptr)
28-
delete[] nchip;
21+
delete[] chi0;
22+
delete[] shchi;
23+
delete[] chiortho;
24+
delete[] nchip;
2925
}
3026

3127
void Stochastic_WF::init(const int nks_in)
@@ -83,6 +79,7 @@ void Init_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in)
8379
stowf.chi0[ik].c[i] = 1.0 / sqrt(double(nchi));
8480
}
8581
}
82+
stowf.nchip_max = tmpnchip;
8683
}
8784

8885
void Update_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in)
@@ -153,6 +150,7 @@ void Init_Com_Orbitals(Stochastic_WF& stowf, K_Vectors& kv)
153150
++tmpnchip;
154151
stowf.nchip[ik] = tmpnchip;
155152
stowf.chi0[ik].create(tmpnchip, ndim, true);
153+
stowf.nchip_max = std::max(tmpnchip,stowf.nchip_max);
156154

157155
const int re = totnpw[ik] % ngroup;
158156
int ip = 0, ig0 = 0;
@@ -195,7 +193,9 @@ void Init_Com_Orbitals(Stochastic_WF& stowf, K_Vectors& kv)
195193
const int ndim = GlobalC::wf.npwx;
196194
for (int ik = 0; ik < kv.nks; ++ik)
197195
{
196+
stowf.nchip[ik] = ndim;
198197
stowf.chi0[ik].create(stowf.nchip[ik], ndim, true);
198+
stowf.nchip_max = ndim;
199199
for (int ichi = 0; ichi < kv.ngk[ik]; ++ichi)
200200
{
201201
stowf.chi0[ik](ichi, ichi) = 1;

source/src_pw/sto_wf.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class Stochastic_WF
2323
ModuleBase::ComplexMatrix* chiortho; // stochastic wavefunctions in reciprocal space after orthogonalization with KS wavefunctions
2424
ModuleBase::ComplexMatrix* shchi; // sqrt(f(H))|chi>
2525
int nchi; // Total number of stochastic orbitals
26-
int *nchip; // The number of stochatic obitals in current process of each k point.
26+
int *nchip; // The number of stochastic orbitals in current process of each k point.
27+
int nchip_max = 0; // Max number of stochastic orbitals among all k points.
2728
int nks; //number of k-points
2829

2930
int nbands_diag; // number of bands obtained from diagonalization

0 commit comments

Comments
 (0)