@@ -465,8 +465,19 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
465465 Stochastic_hchi& stohchi = stoiter.stohchi ;
466466 const int npwx = GlobalC::wf.npwx ;
467467
468- double * spolyv = new double [nche_dos];
469- ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos);
468+ double * spolyv = nullptr ;
469+ std::complex <double > *allorderchi = nullptr ;
470+ if (stoiter.method == 1 )
471+ {
472+ spolyv = new double [nche_dos];
473+ ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos);
474+ }
475+ else
476+ {
477+ spolyv = new double [nche_dos*nche_dos];
478+ ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos*nche_dos);
479+ allorderchi = new std::complex <double > [this ->stowf .nchip_max * npwx * nche_dos];
480+ }
470481 cout<<" 1. TracepolyA:" <<endl;
471482 for (int ik = 0 ;ik < nk;ik++)
472483 {
@@ -477,19 +488,37 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
477488 }
478489 stohchi.current_ik = ik;
479490 const int npw = GlobalC::kv.ngk [ik];
480- const int nchip = this ->stowf .nchip [ik];
491+ const int nchipk = this ->stowf .nchip [ik];
481492
482493 complex <double > * pchi;
483494 if (GlobalV::NBANDS > 0 )
484495 pchi = stowf.chiortho [ik].c ;
485496 else
486497 pchi = stowf.chi0 [ik].c ;
487- che.tracepolyA (&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchip);
488- for (int i = 0 ; i < nche_dos ; ++i)
498+ if (stoiter.method == 1 )
499+ {
500+ che.tracepolyA (&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchipk);
501+ for (int i = 0 ; i < nche_dos ; ++i)
502+ {
503+ spolyv[i] += che.polytrace [i] * GlobalC::kv.wk [ik] / 2 ;
504+ }
505+ }
506+ else
489507 {
490- spolyv[i] += che.polytrace [i] * GlobalC::kv.wk [ik] / 2 ;
508+ ModuleBase::GlobalFunc::ZEROS (allorderchi, this ->stowf .nchip_max * npwx * nche_dos);
509+ che.calpolyvec_complex (&stohchi, &Stochastic_hchi::hchi_norm, pchi, allorderchi, npw, npwx, nchipk);
510+ double * vec_all= (double *) allorderchi;
511+ char trans = ' T' ;
512+ char normal = ' N' ;
513+ double one = 1 ;
514+ int LDA = npwx * nchipk * 2 ;
515+ int M = npwx * nchipk * 2 ;
516+ int N = nche_dos;
517+ double kweight = GlobalC::kv.wk [ik] / 2 ;
518+ dgemm_ (&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
491519 }
492520 }
521+ if (stoiter.method == 2 ) delete[] allorderchi;
493522
494523 string dosfile = GlobalV::global_out_dir+" DOS1_smearing.dat" ;
495524 ofstream ofsdos (dosfile.c_str ());
@@ -498,17 +527,26 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
498527 ModuleBase::GlobalFunc::ZEROS (dos,ndos);
499528 stoiter.stofunc .sigma = sigmain / ModuleBase::Ry_to_eV;
500529 double sum = 0 ;
501- double error = 0 ;
530+ double maxerror = 0 ;
502531 ofsdos<<setw (8 )<<" ## E(eV) " <<setw (20 )<<" dos(eV^-1)" <<setw (20 )<<" sum" <<setw (20 )<<" Error(eV^-1)" <<endl;
503532 cout<<" 2. Dos:" <<endl;
504533 int n10 = ndos/10 ;
505534 int percent = 10 ;
506535 for (int ie = 0 ; ie < ndos; ++ie)
507536 {
508- stoiter.stofunc .targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
509- che.calcoef_real (&stoiter.stofunc , &Sto_Func<double >::ngauss);
510537 double KS_dos = 0 ;
511- double sto_dos = BlasConnector::dot (nche_dos,che.coef_real ,1 ,spolyv,1 );
538+ double sto_dos = 0 ;
539+ stoiter.stofunc .targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
540+ if (stoiter.method == 1 )
541+ {
542+ che.calcoef_real (&stoiter.stofunc , &Sto_Func<double >::ngauss);
543+ sto_dos = BlasConnector::dot (nche_dos,che.coef_real ,1 ,spolyv,1 );
544+ }
545+ else
546+ {
547+ che.calcoef_real (&stoiter.stofunc , &Sto_Func<double >::nroot_gauss);
548+ sto_dos = stoiter.vTMv (che.coef_real ,spolyv,nche_dos);
549+ }
512550 if (GlobalV::NBANDS > 0 )
513551 {
514552 for (int ik = 0 ; ik < nk; ++ik)
@@ -525,11 +563,23 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
525563 MPI_Allreduce (MPI_IN_PLACE, &KS_dos, 1 , MPI_DOUBLE, MPI_SUM , STO_WORLD);
526564 MPI_Allreduce (MPI_IN_PLACE, &sto_dos, 1 , MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD);
527565#endif
528- double tmpre = che.coef_real [nche_dos-1 ] * spolyv[nche_dos-1 ];
566+ double tmpre = 0 ;
567+ if (stoiter.method == 1 )
568+ {
569+ tmpre = che.coef_real [nche_dos-1 ] * spolyv[nche_dos-1 ];
570+ }
571+ else
572+ {
573+ const int norder = nche_dos;
574+ double last_coef = che.coef_real [norder-1 ];
575+ double last_spolyv = spolyv[norder*norder - 1 ];
576+ tmpre = last_coef *(BlasConnector::dot (norder,che.coef_real ,1 ,spolyv+norder*(norder-1 ),1 )
577+ + BlasConnector::dot (norder,che.coef_real ,1 ,spolyv+norder-1 ,norder)-last_coef*last_spolyv);
578+ }
529579#ifdef __MPI
530580 MPI_Allreduce (MPI_IN_PLACE, &tmpre, 1 , MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD);
531581#endif
532- if (error < tmpre) error = tmpre;
582+ if (maxerror < tmpre) maxerror = tmpre;
533583 dos[ie] = (KS_dos + sto_dos) / ModuleBase::Ry_to_eV;
534584 sum += dos[ie];
535585 ofsdos <<setw (8 )<< emin + ie * de <<setw (20 )<<dos[ie]<<setw (20 )<<sum * de <<setw (20 ) <<tmpre <<endl;
@@ -541,7 +591,7 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
541591 }
542592 cout<<endl;
543593 cout<<" Finish DOS" <<endl;
544- cout<<scientific<<" DOS max Chebyshev Error: " <<error <<endl;
594+ cout<<scientific<<" DOS max Chebyshev Error: " <<maxerror <<endl;
545595 delete[] dos;
546596 delete[] spolyv;
547597 return ;
0 commit comments