@@ -454,7 +454,7 @@ void ESolver_SDFT_PW::sKG(const int nche_KG, const double fwhmin, const double w
454454 ModuleBase::timer::tick (this ->classname ," sKG" );
455455}
456456
457- void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const double emin, const double emax, const double de)
457+ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const double emin, const double emax, const double de, const int npart )
458458{
459459 cout<<" =========================" <<endl;
460460 cout<<" ###Calculating Dos....###" <<endl;
@@ -476,7 +476,8 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
476476 {
477477 spolyv = new double [nche_dos*nche_dos];
478478 ModuleBase::GlobalFunc::ZEROS (spolyv, nche_dos*nche_dos);
479- allorderchi = new std::complex <double > [this ->stowf .nchip_max * npwx * nche_dos];
479+ int nchip_new = ceil ((double )this ->stowf .nchip_max / npart);
480+ allorderchi = new std::complex <double > [nchip_new * npwx * nche_dos];
480481 }
481482 cout<<" 1. TracepolyA:" <<endl;
482483 for (int ik = 0 ;ik < nk;ik++)
@@ -490,7 +491,7 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
490491 const int npw = GlobalC::kv.ngk [ik];
491492 const int nchipk = this ->stowf .nchip [ik];
492493
493- complex <double > * pchi;
494+ std:: complex <double > * pchi;
494495 if (GlobalV::NBANDS > 0 )
495496 pchi = stowf.chiortho [ik].c ;
496497 else
@@ -505,17 +506,28 @@ void ESolver_SDFT_PW:: caldos( const int nche_dos, const double sigmain, const d
505506 }
506507 else
507508 {
508- ModuleBase::GlobalFunc::ZEROS (allorderchi, this ->stowf .nchip_max * npwx * nche_dos);
509- che.calpolyvec_complex (&stohchi, &Stochastic_hchi::hchi_norm, pchi, allorderchi, npw, npwx, nchipk);
510- double * vec_all= (double *) allorderchi;
509+ int N = nche_dos;
510+ double kweight = GlobalC::kv.wk [ik] / 2 ;
511511 char trans = ' T' ;
512512 char normal = ' N' ;
513513 double one = 1 ;
514- int LDA = npwx * nchipk * 2 ;
515- int M = npwx * nchipk * 2 ;
516- int N = nche_dos;
517- double kweight = GlobalC::kv.wk [ik] / 2 ;
518- dgemm_ (&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
514+ for (int ipart = 0 ; ipart < npart ; ++ipart)
515+ {
516+ int nchipk_new = nchipk / npart;
517+ int start_nchipk = ipart * nchipk_new + nchipk % npart;
518+ if (ipart < nchipk % npart)
519+ {
520+ nchipk_new++;
521+ start_nchipk = ipart * nchipk_new;
522+ }
523+ ModuleBase::GlobalFunc::ZEROS (allorderchi, nchipk_new * npwx * nche_dos);
524+ std::complex <double > *tmpchi = pchi + start_nchipk * npwx;
525+ che.calpolyvec_complex (&stohchi, &Stochastic_hchi::hchi_norm, tmpchi, allorderchi, npw, npwx, nchipk_new);
526+ double * vec_all= (double *) allorderchi;
527+ int LDA = npwx * nchipk_new * 2 ;
528+ int M = npwx * nchipk_new * 2 ;
529+ dgemm_ (&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N);
530+ }
519531 }
520532 }
521533 if (stoiter.method == 2 ) delete[] allorderchi;
0 commit comments