@@ -56,7 +56,7 @@ void ESolver_SDFT_PW::check_che(const int nche_in)
5656 while (1 )
5757 {
5858 bool converge;
59- converge= chetest.checkconverge (&stohchi, &Stochastic_hchi::hchi_reciprocal ,
59+ converge= chetest.checkconverge (&stohchi, &Stochastic_hchi::hchi_norm ,
6060 pchi, npw, stohchi.Emax , stohchi.Emin , 5.0 );
6161
6262 if (!converge)
@@ -263,10 +263,10 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
263263 ModuleBase::GlobalFunc::COPYARRAY (hj1sfpsi_out, j2sfpsi.get_pointer (), ndim*totbands_per*npwx);
264264
265265 /*
266- // stohchi.hchi_reciprocal (psi0.get_pointer(), hpsi0.get_pointer(), totbands_per);
267- // stohchi.hchi_reciprocal (sfpsi0.get_pointer(), hsfpsi0.get_pointer(), totbands_per);
268- // stohchi.hchi_reciprocal (j1psi.get_pointer(), j2psi.get_pointer(), ndim*totbands_per);
269- // stohchi.hchi_reciprocal (j1sfpsi.get_pointer(), j2sfpsi.get_pointer(), ndim*totbands_per);
266+ // stohchi.hchi_norm (psi0.get_pointer(), hpsi0.get_pointer(), totbands_per);
267+ // stohchi.hchi_norm (sfpsi0.get_pointer(), hsfpsi0.get_pointer(), totbands_per);
268+ // stohchi.hchi_norm (j1psi.get_pointer(), j2psi.get_pointer(), ndim*totbands_per);
269+ // stohchi.hchi_norm (j1sfpsi.get_pointer(), j2sfpsi.get_pointer(), ndim*totbands_per);
270270 // double Ebar = (stohchi.Emin + stohchi.Emax)/2;
271271 // double DeltaE = (stohchi.Emax - stohchi.Emin)/2;
272272 // for(int ib = 0 ; ib < totbands_per ; ++ib)
@@ -307,8 +307,8 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
307307
308308 // (1-f)
309309 che.calcoef_real (&stoiter.stofunc ,&Sto_Func<double >::n_fd);
310- che.calfinalvec_real (&stohchi, &Stochastic_hchi::hchi_reciprocal , j1sfpsi.get_pointer (), j1sfpsi.get_pointer (), npw, npwx, totbands_per*ndim);
311- che.calfinalvec_real (&stohchi, &Stochastic_hchi::hchi_reciprocal , j2sfpsi.get_pointer (), j2sfpsi.get_pointer (), npw, npwx, totbands_per*ndim);
310+ che.calfinalvec_real (&stohchi, &Stochastic_hchi::hchi_norm , j1sfpsi.get_pointer (), j1sfpsi.get_pointer (), npw, npwx, totbands_per*ndim);
311+ che.calfinalvec_real (&stohchi, &Stochastic_hchi::hchi_norm , j2sfpsi.get_pointer (), j2sfpsi.get_pointer (), npw, npwx, totbands_per*ndim);
312312
313313 psi::Psi<std::complex <double >> *p_j1psi = &j1psi;
314314 psi::Psi<std::complex <double >> *p_j2psi = &j2psi;
@@ -361,9 +361,9 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
361361 }
362362
363363 // exp(iHdt)|chi>
364- chet.calfinalvec_complex (&stohchi, &Stochastic_hchi::hchi_reciprocal , &exppsi (ksbandper,0 ), &exppsi (ksbandper,0 ), npw, npwx, nchip);
364+ chet.calfinalvec_complex (&stohchi, &Stochastic_hchi::hchi_norm , &exppsi (ksbandper,0 ), &exppsi (ksbandper,0 ), npw, npwx, nchip);
365365 // exp(-iHdt)|shchi>
366- chet2.calfinalvec_complex (&stohchi, &Stochastic_hchi::hchi_reciprocal , &expsfpsi (ksbandper,0 ), &expsfpsi (ksbandper,0 ), npw, npwx, nchip);
366+ chet2.calfinalvec_complex (&stohchi, &Stochastic_hchi::hchi_norm , &expsfpsi (ksbandper,0 ), &expsfpsi (ksbandper,0 ), npw, npwx, nchip);
367367 psi::Psi<std::complex <double >> *p_exppsi = &exppsi;
368368#ifdef __MPI
369369 psi::Psi<std::complex <double >> exppsi_tot;
@@ -454,53 +454,111 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
454454 ModuleBase::timer::tick (this ->classname ," sKG" );
455455}
456456
457- void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const double emin, const double emax, const double de)
457+ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const double emin, const double emax, const double de, const int npart )
458458{
459- cout<<" Calculating Dos...." <<endl;
459+ cout<<" =========================" <<endl;
460+ cout<<" ###Calculating Dos....###" <<endl;
461+ cout<<" =========================" <<endl;
460462 ModuleBase::Chebyshev<double > che (nche_dos);
461463 const int nk = GlobalC::kv.nks ;
462464 Stochastic_Iter& stoiter = ((hsolver::HSolverPW_SDFT*)phsol)->stoiter ;
463465 Stochastic_hchi& stohchi = stoiter.stohchi ;
464466 const int npwx = GlobalC::wf.npwx ;
465467
466- double * spolyv = new double [nche_dos];
467- ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos);
468+ double * spolyv = nullptr ;
469+ std::complex <double > *allorderchi = nullptr ;
470+ if (stoiter.method == 1 )
471+ {
472+ spolyv = new double [nche_dos];
473+ ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos);
474+ }
475+ else
476+ {
477+ spolyv = new double [nche_dos*nche_dos];
478+ ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos*nche_dos);
479+ int nchip_new = ceil ((double )this ->stowf .nchip_max / npart);
480+ allorderchi = new std::complex <double > [nchip_new * npwx * nche_dos];
481+ }
482+ cout<<" 1. TracepolyA:" <<endl;
468483 for (int ik = 0 ;ik < nk;ik++)
469484 {
485+ cout<<" ik: " <<ik+1 <<endl;
470486 if (nk > 1 )
471487 {
472488 this ->phami ->updateHk (ik);
473489 }
474490 stohchi.current_ik = ik;
475491 const int npw = GlobalC::kv.ngk [ik];
476- const int nchip = this ->stowf .nchip [ik];
492+ const int nchipk = this ->stowf .nchip [ik];
477493
478- complex <double > * pchi;
494+ std:: complex <double > * pchi;
479495 if (GlobalV::NBANDS > 0 )
480496 pchi = stowf.chiortho [ik].c ;
481497 else
482498 pchi = stowf.chi0 [ik].c ;
483- che.tracepolyA (&stohchi, &Stochastic_hchi::hchi_reciprocal, pchi, npw, npwx, nchip);
484- for (int i = 0 ; i < nche_dos ; ++i)
499+ if (stoiter.method == 1 )
485500 {
486- spolyv[i] += che.polytrace [i] * GlobalC::kv.wk [ik] / 2 ;
501+ che.tracepolyA (&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchipk);
502+ for (int i = 0 ; i < nche_dos ; ++i)
503+ {
504+ spolyv[i] += che.polytrace [i] * GlobalC::kv.wk [ik] / 2 ;
505+ }
506+ }
507+ else
508+ {
509+ int N = nche_dos;
510+ double kweight = GlobalC::kv.wk [ik] / 2 ;
511+ char trans = ' T' ;
512+ char normal = ' N' ;
513+ double one = 1 ;
514+ for (int ipart = 0 ; ipart < npart ; ++ipart)
515+ {
516+ int nchipk_new = nchipk / npart;
517+ int start_nchipk = ipart * nchipk_new + nchipk % npart;
518+ if (ipart < nchipk % npart)
519+ {
520+ nchipk_new++;
521+ start_nchipk = ipart * nchipk_new;
522+ }
523+ ModuleBase::GlobalFunc::ZEROS (allorderchi, nchipk_new * npwx * nche_dos);
524+ std::complex <double > *tmpchi = pchi + start_nchipk * npwx;
525+ che.calpolyvec_complex (&stohchi, &Stochastic_hchi::hchi_norm, tmpchi, allorderchi, npw, npwx, nchipk_new);
526+ double * vec_all= (double *) allorderchi;
527+ int LDA = npwx * nchipk_new * 2 ;
528+ int M = npwx * nchipk_new * 2 ;
529+ dgemm_ (&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
530+ }
487531 }
488532 }
533+ if (stoiter.method == 2 ) delete[] allorderchi;
534+
489535 string dosfile = GlobalV::global_out_dir+" DOS1_smearing.dat" ;
490536 ofstream ofsdos (dosfile.c_str ());
491537 int ndos = int ((emax-emin) / de)+1 ;
492538 double *dos = new double [ndos];
493539 ModuleBase::GlobalFunc::ZEROS (dos,ndos);
494540 stoiter.stofunc .sigma = sigmain / ModuleBase::Ry_to_eV;
495541 double sum = 0 ;
496- double error = 0 ;
542+ double maxerror = 0 ;
497543 ofsdos<<setw (8 )<<" ## E(eV) " <<setw (20 )<<" dos(eV^-1)" <<setw (20 )<<" sum" <<setw (20 )<<" Error(eV^-1)" <<endl;
544+ cout<<" 2. Dos:" <<endl;
545+ int n10 = ndos/10 ;
546+ int percent = 10 ;
498547 for (int ie = 0 ; ie < ndos; ++ie)
499548 {
500- stoiter.stofunc .targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
501- che.calcoef_real (&stoiter.stofunc , &Sto_Func<double >::ngauss);
502549 double KS_dos = 0 ;
503- double sto_dos = BlasConnector::dot (nche_dos,che.coef_real ,1 ,spolyv,1 );
550+ double sto_dos = 0 ;
551+ stoiter.stofunc .targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
552+ if (stoiter.method == 1 )
553+ {
554+ che.calcoef_real (&stoiter.stofunc , &Sto_Func<double >::ngauss);
555+ sto_dos = BlasConnector::dot (nche_dos,che.coef_real ,1 ,spolyv,1 );
556+ }
557+ else
558+ {
559+ che.calcoef_real (&stoiter.stofunc , &Sto_Func<double >::nroot_gauss);
560+ sto_dos = stoiter.vTMv (che.coef_real ,spolyv,nche_dos);
561+ }
504562 if (GlobalV::NBANDS > 0 )
505563 {
506564 for (int ik = 0 ; ik < nk; ++ik)
@@ -517,16 +575,35 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
517575 MPI_Allreduce (MPI_IN_PLACE, &KS_dos, 1 , MPI_DOUBLE, MPI_SUM , STO_WORLD);
518576 MPI_Allreduce (MPI_IN_PLACE, &sto_dos, 1 , MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD);
519577#endif
520- double tmpre = che.coef_real [nche_dos-1 ] * spolyv[nche_dos-1 ];
578+ double tmpre = 0 ;
579+ if (stoiter.method == 1 )
580+ {
581+ tmpre = che.coef_real [nche_dos-1 ] * spolyv[nche_dos-1 ];
582+ }
583+ else
584+ {
585+ const int norder = nche_dos;
586+ double last_coef = che.coef_real [norder-1 ];
587+ double last_spolyv = spolyv[norder*norder - 1 ];
588+ tmpre = last_coef *(BlasConnector::dot (norder,che.coef_real ,1 ,spolyv+norder*(norder-1 ),1 )
589+ + BlasConnector::dot (norder,che.coef_real ,1 ,spolyv+norder-1 ,norder)-last_coef*last_spolyv);
590+ }
521591#ifdef __MPI
522592 MPI_Allreduce (MPI_IN_PLACE, &tmpre, 1 , MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD);
523593#endif
524- if (error < tmpre) error = tmpre;
594+ if (maxerror < tmpre) maxerror = tmpre;
525595 dos[ie] = (KS_dos + sto_dos) / ModuleBase::Ry_to_eV;
526596 sum += dos[ie];
527- ofsdos <<setw (8 )<< emin + ie * de <<setw (20 )<<dos[ie]<<setw (20 )<<sum * de <<setw (20 ) <<error <<endl;
597+ ofsdos <<setw (8 )<< emin + ie * de <<setw (20 )<<dos[ie]<<setw (20 )<<sum * de <<setw (20 ) <<tmpre <<endl;
598+ if (ie%n10 == n10 -1 )
599+ {
600+ cout<<percent<<" %" <<" " ;
601+ percent+=10 ;
602+ }
528603 }
529- cout<<scientific<<" DOS max Chebyshev Error: " <<error<<endl;
604+ cout<<endl;
605+ cout<<" Finish DOS" <<endl;
606+ cout<<scientific<<" DOS max Chebyshev Error: " <<maxerror<<endl;
530607 delete[] dos;
531608 delete[] spolyv;
532609 return ;
0 commit comments