Skip to content

Commit cab5894

Browse files
authored
Refactor: split sum_stoband to sum_stoband and cal_storho (#5600)
1 parent 6a98a1a commit cab5894

File tree

4 files changed

+145
-48
lines changed

4 files changed

+145
-48
lines changed

source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -479,8 +479,7 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
479479
{
480480
ModuleBase::TITLE("Stochastic_Iter", "sum_stoband");
481481
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband");
482-
int nrxx = wfc_basis->nrxx;
483-
int npwx = wfc_basis->npwk_max;
482+
const int npwx = wfc_basis->npwk_max;
484483
const int norder = p_che->norder;
485484

486485
//---------------cal demet-----------------------
@@ -557,33 +556,53 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
557556
MPI_Allreduce(MPI_IN_PLACE, &sto_eband, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
558557
#endif
559558
pes->f_en.eband += sto_eband;
560-
//---------------------cal rho-------------------------
561-
double* sto_rho = new double[nrxx];
559+
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband");
560+
}
562561

563-
double dr3 = GlobalC::ucell.omega / wfc_basis->nxyz;
564-
double tmprho, tmpne;
565-
T outtem;
566-
double sto_ne = 0;
567-
ModuleBase::GlobalFunc::ZEROS(sto_rho, nrxx);
562+
template <typename T, typename Device>
563+
void Stochastic_Iter<T, Device>::cal_storho(Stochastic_WF<T, Device>& stowf,
564+
elecstate::ElecStatePW<T, Device>* pes,
565+
ModulePW::PW_Basis_K* wfc_basis)
566+
{
567+
ModuleBase::TITLE("Stochastic_Iter", "cal_storho");
568+
ModuleBase::timer::tick("Stochastic_Iter", "cal_storho");
569+
//---------------------cal rho-------------------------
570+
const int nrxx = wfc_basis->nrxx;
571+
const int npwx = wfc_basis->npwk_max;
572+
const int nspin = PARAM.inp.nspin;
568573

569574
T* porter = nullptr;
570575
resmem_complex_op()(this->ctx, porter, nrxx);
571-
double out2;
572576

573-
double* ksrho = nullptr;
574-
if (PARAM.inp.nbands > 0 && GlobalV::MY_STOGROUP == 0)
577+
std::vector<double*> sto_rho(nspin);
578+
for(int is = 0; is < nspin; ++is)
579+
{
580+
sto_rho[is] = pes->charge->rho[is];
581+
}
582+
std::vector<double> _tmprho;
583+
if (PARAM.inp.nbands > 0)
575584
{
576-
ksrho = new double[nrxx];
577-
ModuleBase::GlobalFunc::DCOPY(pes->charge->rho[0], ksrho, nrxx);
578-
setmem_var_op()(this->ctx, pes->rho[0], 0, nrxx);
579-
// ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx);
585+
// If there are KS orbitals, we need to allocate a separate buffer for sto_rho
586+
_tmprho.resize(nrxx * nspin);
587+
for(int is = 0; is < nspin; ++is)
588+
{
589+
sto_rho[is] = _tmprho.data() + is * nrxx;
590+
}
580591
}
581592

593+
// pes->rho is device memory; when using CPU and double precision, we do not need to allocate separate memory for pes->rho
594+
if (PARAM.inp.device != "gpu" && PARAM.inp.precision != "single") {
595+
pes->rho = reinterpret_cast<Real **>(sto_rho.data());
596+
}
597+
for (int is = 0; is < nspin; is++)
598+
{
599+
setmem_var_op()(this->ctx, pes->rho[is], 0, nrxx);
600+
}
582601
for (int ik = 0; ik < this->pkv->get_nks(); ++ik)
583602
{
584603
const int nchip_ik = nchip[ik];
585604
int current_spin = 0;
586-
if (PARAM.inp.nspin == 2)
605+
if (nspin == 2)
587606
{
588607
current_spin = this->pkv->isk[ik];
589608
}
@@ -602,27 +621,50 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
602621
}
603622
}
604623
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
605-
for (int ii = 0; ii < PARAM.inp.nspin; ii++) {
606-
castmem_var_d2h_op()(this->cpu_ctx, this->ctx, pes->charge->rho[ii], pes->rho[ii], nrxx);
624+
for(int is = 0; is < nspin; ++is)
625+
{
626+
castmem_var_d2h_op()(this->cpu_ctx, this->ctx, sto_rho[is], pes->rho[is], nrxx);
607627
}
608628
}
629+
else
630+
{
631+
// We need to set pes->rho back to the original value
632+
pes->rho = reinterpret_cast<Real **>(pes->charge->rho);
633+
}
634+
609635
delmem_complex_op()(this->ctx, porter);
610636
#ifdef __MPI
611-
// temporary, rho_mpi should be rewrite as a tool function! Now it only treats pes->charge->rho
612-
pes->charge->rho_mpi();
637+
if(GlobalV::KPAR > 1)
638+
{
639+
for (int is = 0; is < nspin; ++is)
640+
{
641+
pes->charge->reduce_diff_pools(sto_rho[is]);
642+
}
643+
}
613644
#endif
614-
for (int ir = 0; ir < nrxx; ++ir)
645+
646+
double sto_ne = 0;
647+
for(int is = 0; is < nspin; ++is)
615648
{
616-
tmprho = pes->charge->rho[0][ir] / GlobalC::ucell.omega;
617-
sto_rho[ir] = tmprho;
618-
sto_ne += tmprho;
649+
#ifdef _OPENMP
650+
#pragma omp parallel for reduction(+ : sto_ne)
651+
#endif
652+
for (int ir = 0; ir < nrxx; ++ir)
653+
{
654+
sto_rho[is][ir] /= GlobalC::ucell.omega;
655+
sto_ne += sto_rho[is][ir];
656+
}
619657
}
620-
sto_ne *= dr3;
658+
659+
sto_ne *= GlobalC::ucell.omega / wfc_basis->nxyz;
621660

622661
#ifdef __MPI
623662
MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
624663
MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
625-
MPI_Allreduce(MPI_IN_PLACE, sto_rho, nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
664+
for(int is = 0; is < nspin; ++is)
665+
{
666+
MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
667+
}
626668
#endif
627669
double factor = targetne / (KS_ne + sto_ne);
628670
if (std::abs(factor - 1) > 1e-10)
@@ -635,32 +677,32 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
635677
factor = 1;
636678
}
637679

638-
if (GlobalV::MY_STOGROUP == 0)
680+
for (int is = 0; is < 1; ++is)
639681
{
640682
if (PARAM.inp.nbands > 0)
641683
{
642-
ModuleBase::GlobalFunc::DCOPY(ksrho, pes->charge->rho[0], nrxx);
684+
#ifdef _OPENMP
685+
#pragma omp parallel for
686+
#endif
687+
for (int ir = 0; ir < nrxx; ++ir)
688+
{
689+
pes->charge->rho[is][ir] += sto_rho[is][ir];
690+
pes->charge->rho[is][ir] *= factor;
691+
}
643692
}
644693
else
645694
{
646-
ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx);
647-
}
648-
}
649-
650-
if (GlobalV::MY_STOGROUP == 0)
651-
{
652-
for (int is = 0; is < 1; ++is)
653-
{
695+
#ifdef _OPENMP
696+
#pragma omp parallel for
697+
#endif
654698
for (int ir = 0; ir < nrxx; ++ir)
655699
{
656-
pes->charge->rho[is][ir] += sto_rho[ir];
657700
pes->charge->rho[is][ir] *= factor;
658701
}
659702
}
660703
}
661-
delete[] sto_rho;
662-
delete[] ksrho;
663-
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband");
704+
705+
ModuleBase::timer::tick("Stochastic_Iter", "cal_storho");
664706
return;
665707
}
666708

source/module_hamilt_pw/hamilt_stodft/sto_iter.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,72 @@ class Stochastic_Iter
4444
StoChe<Real, Device>& stoche,
4545
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto);
4646

47+
/**
48+
* @brief sum demet and eband energies for each k point and each band
49+
*
50+
* @param stowf stochastic wave function
51+
* @param pes elecstate
52+
* @param pHamilt hamiltonian
53+
* @param wfc_basis wfc pw basis
54+
*/
4755
void sum_stoband(Stochastic_WF<T, Device>& stowf,
4856
elecstate::ElecStatePW<T, Device>* pes,
4957
hamilt::Hamilt<T, Device>* pHamilt,
5058
ModulePW::PW_Basis_K* wfc_basis);
5159

60+
/**
61+
* @brief calculate the density
62+
*
63+
* @param stowf stochastic wave function
64+
* @param pes elecstate
65+
* @param wfc_basis wfc pw basis
66+
*/
67+
void cal_storho(Stochastic_WF<T, Device>& stowf,
68+
elecstate::ElecStatePW<T, Device>* pes,
69+
ModulePW::PW_Basis_K* wfc_basis);
70+
71+
/**
72+
* @brief calculate total number of electrons
73+
*
74+
* @param pes elecstate
75+
* @return double
76+
*/
5277
double calne(elecstate::ElecState* pes);
5378

79+
/**
80+
* @brief solve ne(mu) = ne_target and get chemical potential mu
81+
*
82+
* @param iter scf iteration index
83+
* @param pes elecstate
84+
*/
5485
void itermu(const int iter, elecstate::ElecState* pes);
5586

87+
/**
88+
* @brief orthogonalize stochastic wave functions with KS wave functions
89+
*
90+
* @param ik k point index
91+
* @param psi KS wave functions
92+
* @param stowf stochastic wave functions
93+
*/
5694
void orthog(const int& ik, psi::Psi<T, Device>& psi, Stochastic_WF<T, Device>& stowf);
5795

96+
/**
97+
* @brief check emax and emin
98+
*
99+
* @param ik k point index
100+
* @param istep ion step index
101+
* @param iter scf iteration index
102+
* @param stowf stochastic wave functions
103+
*/
58104
void checkemm(const int& ik, const int istep, const int iter, Stochastic_WF<T, Device>& stowf);
59105

106+
/**
107+
* @brief check precision of Chebyshev expansion
108+
*
109+
* @param ref reference value
110+
* @param thr threshold
111+
* @param info information
112+
*/
60113
void check_precision(const double ref, const double thr, const std::string info);
61114

62115
ModuleBase::Chebyshev<double, Device>* p_che = nullptr;

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,10 @@ void HSolverPW_SDFT<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
128128
MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD);
129129
#endif
130130
}
131-
else
132-
{
133-
for (int is = 0; is < this->nspin; is++)
134-
{
135-
setmem_var_op()(this->ctx, pes_pw->rho[is], 0, pes_pw->charge->nrxx);
136-
}
137-
}
131+
138132
// calculate stochastic rho
139133
stoiter.sum_stoband(stowf, pes_pw, pHamilt, wfc_basis);
134+
stoiter.cal_storho(stowf, pes_pw, wfc_basis);
140135

141136
// will do rho symmetry and energy calculation in esolver
142137
ModuleBase::timer::tick("HSolverPW_SDFT", "solve");

source/module_hsolver/test/test_hsolver_sdft.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
151151
return;
152152
}
153153

154+
template <typename T, typename Device>
155+
void Stochastic_Iter<T, Device>::cal_storho(Stochastic_WF<T, Device>& stowf,
156+
elecstate::ElecStatePW<T, Device>* pes,
157+
ModulePW::PW_Basis_K* wfc_basis)
158+
{
159+
}
160+
154161
Charge::Charge(){};
155162
Charge::~Charge(){};
156163

0 commit comments

Comments
 (0)