@@ -479,8 +479,7 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
479479{
480480 ModuleBase::TITLE (" Stochastic_Iter" , " sum_stoband" );
481481 ModuleBase::timer::tick (" Stochastic_Iter" , " sum_stoband" );
482- int nrxx = wfc_basis->nrxx ;
483- int npwx = wfc_basis->npwk_max ;
482+ const int npwx = wfc_basis->npwk_max ;
484483 const int norder = p_che->norder ;
485484
486485 // ---------------cal demet-----------------------
@@ -557,33 +556,53 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
557556 MPI_Allreduce (MPI_IN_PLACE, &sto_eband, 1 , MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
558557#endif
559558 pes->f_en .eband += sto_eband;
560- // ---------------------cal rho-------------------------
561- double * sto_rho = new double [nrxx];
559+ ModuleBase::timer::tick ( " Stochastic_Iter " , " sum_stoband " );
560+ }
562561
563- double dr3 = GlobalC::ucell.omega / wfc_basis->nxyz ;
564- double tmprho, tmpne;
565- T outtem;
566- double sto_ne = 0 ;
567- ModuleBase::GlobalFunc::ZEROS (sto_rho, nrxx);
562+ template <typename T, typename Device>
563+ void Stochastic_Iter<T, Device>::cal_storho(Stochastic_WF<T, Device>& stowf,
564+ elecstate::ElecStatePW<T, Device>* pes,
565+ ModulePW::PW_Basis_K* wfc_basis)
566+ {
567+ ModuleBase::TITLE (" Stochastic_Iter" , " cal_storho" );
568+ ModuleBase::timer::tick (" Stochastic_Iter" , " cal_storho" );
569+ // ---------------------cal rho-------------------------
570+ const int nrxx = wfc_basis->nrxx ;
571+ const int npwx = wfc_basis->npwk_max ;
572+ const int nspin = PARAM.inp .nspin ;
568573
569574 T* porter = nullptr ;
570575 resmem_complex_op ()(this ->ctx , porter, nrxx);
571- double out2;
572576
573- double * ksrho = nullptr ;
574- if (PARAM.inp .nbands > 0 && GlobalV::MY_STOGROUP == 0 )
577+ std::vector<double *> sto_rho (nspin);
578+ for (int is = 0 ; is < nspin; ++is)
579+ {
580+ sto_rho[is] = pes->charge ->rho [is];
581+ }
582+ std::vector<double > _tmprho;
583+ if (PARAM.inp .nbands > 0 )
575584 {
576- ksrho = new double [nrxx];
577- ModuleBase::GlobalFunc::DCOPY (pes->charge ->rho [0 ], ksrho, nrxx);
578- setmem_var_op ()(this ->ctx , pes->rho [0 ], 0 , nrxx);
579- // ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx);
585+ // If there are KS orbitals, we need to allocate another memory for sto_rho
586+ _tmprho.resize (nrxx * nspin);
587+ for (int is = 0 ; is < nspin; ++is)
588+ {
589+ sto_rho[is] = _tmprho.data () + is * nrxx;
590+ }
580591 }
581592
593+ // pes->rho is a device memory, and when using cpu and double, we donot need to allocate memory for pes->rho
594+ if (PARAM.inp .device != " gpu" && PARAM.inp .precision != " single" ) {
595+ pes->rho = reinterpret_cast <Real **>(sto_rho.data ());
596+ }
597+ for (int is = 0 ; is < nspin; is++)
598+ {
599+ setmem_var_op ()(this ->ctx , pes->rho [is], 0 , nrxx);
600+ }
582601 for (int ik = 0 ; ik < this ->pkv ->get_nks (); ++ik)
583602 {
584603 const int nchip_ik = nchip[ik];
585604 int current_spin = 0 ;
586- if (PARAM. inp . nspin == 2 )
605+ if (nspin == 2 )
587606 {
588607 current_spin = this ->pkv ->isk [ik];
589608 }
@@ -602,27 +621,50 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
602621 }
603622 }
604623 if (PARAM.inp .device == " gpu" || PARAM.inp .precision == " single" ) {
605- for (int ii = 0 ; ii < PARAM.inp .nspin ; ii++) {
606- castmem_var_d2h_op ()(this ->cpu_ctx , this ->ctx , pes->charge ->rho [ii], pes->rho [ii], nrxx);
624+ for (int is = 0 ; is < nspin; ++is)
625+ {
626+ castmem_var_d2h_op ()(this ->cpu_ctx , this ->ctx , sto_rho[is], pes->rho [is], nrxx);
607627 }
608628 }
629+ else
630+ {
631+ // We need to set pes->rho back to the original value
632+ pes->rho = reinterpret_cast <Real **>(pes->charge ->rho );
633+ }
634+
609635 delmem_complex_op ()(this ->ctx , porter);
610636#ifdef __MPI
611- // temporary, rho_mpi should be rewrite as a tool function! Now it only treats pes->charge->rho
612- pes->charge ->rho_mpi ();
637+ if (GlobalV::KPAR > 1 )
638+ {
639+ for (int is = 0 ; is < nspin; ++is)
640+ {
641+ pes->charge ->reduce_diff_pools (sto_rho[is]);
642+ }
643+ }
613644#endif
614- for (int ir = 0 ; ir < nrxx; ++ir)
645+
646+ double sto_ne = 0 ;
647+ for (int is = 0 ; is < nspin; ++is)
615648 {
616- tmprho = pes->charge ->rho [0 ][ir] / GlobalC::ucell.omega ;
617- sto_rho[ir] = tmprho;
618- sto_ne += tmprho;
649+ #ifdef _OPENMP
650+ #pragma omp parallel for reduction(+ : sto_ne)
651+ #endif
652+ for (int ir = 0 ; ir < nrxx; ++ir)
653+ {
654+ sto_rho[is][ir] /= GlobalC::ucell.omega ;
655+ sto_ne += sto_rho[is][ir];
656+ }
619657 }
620- sto_ne *= dr3;
658+
659+ sto_ne *= GlobalC::ucell.omega / wfc_basis->nxyz ;
621660
622661#ifdef __MPI
623662 MPI_Allreduce (MPI_IN_PLACE, &sto_ne, 1 , MPI_DOUBLE, MPI_SUM, POOL_WORLD);
624663 MPI_Allreduce (MPI_IN_PLACE, &sto_ne, 1 , MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
625- MPI_Allreduce (MPI_IN_PLACE, sto_rho, nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
664+ for (int is = 0 ; is < nspin; ++is)
665+ {
666+ MPI_Allreduce (MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
667+ }
626668#endif
627669 double factor = targetne / (KS_ne + sto_ne);
628670 if (std::abs (factor - 1 ) > 1e-10 )
@@ -635,32 +677,32 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
635677 factor = 1 ;
636678 }
637679
638- if (GlobalV::MY_STOGROUP == 0 )
680+ for ( int is = 0 ; is < 1 ; ++is )
639681 {
640682 if (PARAM.inp .nbands > 0 )
641683 {
642- ModuleBase::GlobalFunc::DCOPY (ksrho, pes->charge ->rho [0 ], nrxx);
684+ #ifdef _OPENMP
685+ #pragma omp parallel for
686+ #endif
687+ for (int ir = 0 ; ir < nrxx; ++ir)
688+ {
689+ pes->charge ->rho [is][ir] += sto_rho[is][ir];
690+ pes->charge ->rho [is][ir] *= factor;
691+ }
643692 }
644693 else
645694 {
646- ModuleBase::GlobalFunc::ZEROS (pes->charge ->rho [0 ], nrxx);
647- }
648- }
649-
650- if (GlobalV::MY_STOGROUP == 0 )
651- {
652- for (int is = 0 ; is < 1 ; ++is)
653- {
695+ #ifdef _OPENMP
696+ #pragma omp parallel for
697+ #endif
654698 for (int ir = 0 ; ir < nrxx; ++ir)
655699 {
656- pes->charge ->rho [is][ir] += sto_rho[ir];
657700 pes->charge ->rho [is][ir] *= factor;
658701 }
659702 }
660703 }
661- delete[] sto_rho;
662- delete[] ksrho;
663- ModuleBase::timer::tick (" Stochastic_Iter" , " sum_stoband" );
704+
705+ ModuleBase::timer::tick (" Stochastic_Iter" , " cal_storho" );
664706 return ;
665707}
666708
0 commit comments