Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 83 additions & 41 deletions source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,7 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
{
ModuleBase::TITLE("Stochastic_Iter", "sum_stoband");
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband");
int nrxx = wfc_basis->nrxx;
int npwx = wfc_basis->npwk_max;
const int npwx = wfc_basis->npwk_max;
const int norder = p_che->norder;

//---------------cal demet-----------------------
Expand Down Expand Up @@ -557,33 +556,53 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
MPI_Allreduce(MPI_IN_PLACE, &sto_eband, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
#endif
pes->f_en.eband += sto_eband;
//---------------------cal rho-------------------------
double* sto_rho = new double[nrxx];
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband");
}

double dr3 = GlobalC::ucell.omega / wfc_basis->nxyz;
double tmprho, tmpne;
T outtem;
double sto_ne = 0;
ModuleBase::GlobalFunc::ZEROS(sto_rho, nrxx);
template <typename T, typename Device>
void Stochastic_Iter<T, Device>::cal_storho(Stochastic_WF<T, Device>& stowf,
elecstate::ElecStatePW<T, Device>* pes,
ModulePW::PW_Basis_K* wfc_basis)
{
ModuleBase::TITLE("Stochastic_Iter", "cal_storho");
ModuleBase::timer::tick("Stochastic_Iter", "cal_storho");
//---------------------cal rho-------------------------
const int nrxx = wfc_basis->nrxx;
const int npwx = wfc_basis->npwk_max;
const int nspin = PARAM.inp.nspin;

T* porter = nullptr;
resmem_complex_op()(this->ctx, porter, nrxx);
double out2;

double* ksrho = nullptr;
if (PARAM.inp.nbands > 0 && GlobalV::MY_STOGROUP == 0)
std::vector<double*> sto_rho(nspin);
for(int is = 0; is < nspin; ++is)
{
sto_rho[is] = pes->charge->rho[is];
}
std::vector<double> _tmprho;
if (PARAM.inp.nbands > 0)
{
ksrho = new double[nrxx];
ModuleBase::GlobalFunc::DCOPY(pes->charge->rho[0], ksrho, nrxx);
setmem_var_op()(this->ctx, pes->rho[0], 0, nrxx);
// ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx);
// If there are KS orbitals, we need to allocate another memory for sto_rho
_tmprho.resize(nrxx * nspin);
for(int is = 0; is < nspin; ++is)
{
sto_rho[is] = _tmprho.data() + is * nrxx;
}
}

// pes->rho is device memory; when running on CPU in double precision, we do not need to allocate separate memory for pes->rho
if (PARAM.inp.device != "gpu" && PARAM.inp.precision != "single") {
pes->rho = reinterpret_cast<Real **>(sto_rho.data());
}
for (int is = 0; is < nspin; is++)
{
setmem_var_op()(this->ctx, pes->rho[is], 0, nrxx);
}
for (int ik = 0; ik < this->pkv->get_nks(); ++ik)
{
const int nchip_ik = nchip[ik];
int current_spin = 0;
if (PARAM.inp.nspin == 2)
if (nspin == 2)
{
current_spin = this->pkv->isk[ik];
}
Expand All @@ -602,27 +621,50 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
}
}
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single") {
for (int ii = 0; ii < PARAM.inp.nspin; ii++) {
castmem_var_d2h_op()(this->cpu_ctx, this->ctx, pes->charge->rho[ii], pes->rho[ii], nrxx);
for(int is = 0; is < nspin; ++is)
{
castmem_var_d2h_op()(this->cpu_ctx, this->ctx, sto_rho[is], pes->rho[is], nrxx);
}
}
else
{
// We need to set pes->rho back to the original value
pes->rho = reinterpret_cast<Real **>(pes->charge->rho);
}

delmem_complex_op()(this->ctx, porter);
#ifdef __MPI
// temporary, rho_mpi should be rewrite as a tool function! Now it only treats pes->charge->rho
pes->charge->rho_mpi();
if(GlobalV::KPAR > 1)
{
for (int is = 0; is < nspin; ++is)
{
pes->charge->reduce_diff_pools(sto_rho[is]);
}
}
#endif
for (int ir = 0; ir < nrxx; ++ir)

double sto_ne = 0;
for(int is = 0; is < nspin; ++is)
{
tmprho = pes->charge->rho[0][ir] / GlobalC::ucell.omega;
sto_rho[ir] = tmprho;
sto_ne += tmprho;
#ifdef _OPENMP
#pragma omp parallel for reduction(+ : sto_ne)
#endif
for (int ir = 0; ir < nrxx; ++ir)
{
sto_rho[is][ir] /= GlobalC::ucell.omega;
sto_ne += sto_rho[is][ir];
}
}
sto_ne *= dr3;

sto_ne *= GlobalC::ucell.omega / wfc_basis->nxyz;

#ifdef __MPI
MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
MPI_Allreduce(MPI_IN_PLACE, sto_rho, nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
for(int is = 0; is < nspin; ++is)
{
MPI_Allreduce(MPI_IN_PLACE, sto_rho[is], nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD);
}
#endif
double factor = targetne / (KS_ne + sto_ne);
if (std::abs(factor - 1) > 1e-10)
Expand All @@ -635,32 +677,32 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
factor = 1;
}

if (GlobalV::MY_STOGROUP == 0)
for (int is = 0; is < 1; ++is)
{
if (PARAM.inp.nbands > 0)
{
ModuleBase::GlobalFunc::DCOPY(ksrho, pes->charge->rho[0], nrxx);
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int ir = 0; ir < nrxx; ++ir)
{
pes->charge->rho[is][ir] += sto_rho[is][ir];
pes->charge->rho[is][ir] *= factor;
}
}
else
{
ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[0], nrxx);
}
}

if (GlobalV::MY_STOGROUP == 0)
{
for (int is = 0; is < 1; ++is)
{
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int ir = 0; ir < nrxx; ++ir)
{
pes->charge->rho[is][ir] += sto_rho[ir];
pes->charge->rho[is][ir] *= factor;
}
}
}
delete[] sto_rho;
delete[] ksrho;
ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband");

ModuleBase::timer::tick("Stochastic_Iter", "cal_storho");
return;
}

Expand Down
53 changes: 53 additions & 0 deletions source/module_hamilt_pw/hamilt_stodft/sto_iter.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,72 @@ class Stochastic_Iter
StoChe<Real, Device>& stoche,
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto);

/**
 * @brief Sum the demet and eband energies over each k point and each band.
 *
 * NOTE(review): the definition in sto_iter.cpp appears to add the reduced
 * stochastic band energy to pes->f_en.eband; where demet ends up is not
 * visible from this declaration -- confirm against the implementation.
 *
 * @param stowf stochastic wave functions
 * @param pes electronic state (elecstate) object
 * @param pHamilt Hamiltonian
 * @param wfc_basis plane-wave basis of the wave functions
 */
void sum_stoband(Stochastic_WF<T, Device>& stowf,
elecstate::ElecStatePW<T, Device>* pes,
hamilt::Hamilt<T, Device>* pHamilt,
ModulePW::PW_Basis_K* wfc_basis);

/**
 * @brief Calculate the charge density.
 *
 * NOTE(review): the definition in sto_iter.cpp appears to add the stochastic
 * density to pes->charge->rho and rescale it by targetne / (KS_ne + sto_ne)
 * -- confirm against the implementation.
 *
 * @param stowf stochastic wave functions
 * @param pes electronic state (elecstate) object
 * @param wfc_basis plane-wave basis of the wave functions
 */
void cal_storho(Stochastic_WF<T, Device>& stowf,
elecstate::ElecStatePW<T, Device>* pes,
ModulePW::PW_Basis_K* wfc_basis);

/**
 * @brief Calculate the total number of electrons.
 *
 * @param pes electronic state (elecstate) object
 * @return the calculated total number of electrons
 */
double calne(elecstate::ElecState* pes);

/**
 * @brief Solve ne(mu) = ne_target to obtain the chemical potential mu.
 *
 * @param iter SCF iteration index
 * @param pes electronic state (elecstate) object
 */
void itermu(const int iter, elecstate::ElecState* pes);

/**
 * @brief Orthogonalize the stochastic wave functions against the KS wave functions.
 *
 * @param ik k-point index
 * @param psi KS wave functions
 * @param stowf stochastic wave functions
 */
void orthog(const int& ik, psi::Psi<T, Device>& psi, Stochastic_WF<T, Device>& stowf);

/**
 * @brief Check emax and emin.
 *
 * NOTE(review): presumably these are the energy bounds used by the Chebyshev
 * expansion -- confirm against the implementation.
 *
 * @param ik k-point index
 * @param istep ionic step index
 * @param iter SCF iteration index
 * @param stowf stochastic wave functions
 */
void checkemm(const int& ik, const int istep, const int iter, Stochastic_WF<T, Device>& stowf);

/**
 * @brief Check the precision of the Chebyshev expansion.
 *
 * @param ref reference value
 * @param thr threshold
 * @param info descriptive text for this check (presumably used in the
 *             diagnostic output -- TODO confirm)
 */
void check_precision(const double ref, const double thr, const std::string info);

ModuleBase::Chebyshev<double, Device>* p_che = nullptr;
Expand Down
9 changes: 2 additions & 7 deletions source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,10 @@ void HSolverPW_SDFT<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD);
#endif
}
else
{
for (int is = 0; is < this->nspin; is++)
{
setmem_var_op()(this->ctx, pes_pw->rho[is], 0, pes_pw->charge->nrxx);
}
}

// calculate stochastic rho
stoiter.sum_stoband(stowf, pes_pw, pHamilt, wfc_basis);
stoiter.cal_storho(stowf, pes_pw, wfc_basis);

// will do rho symmetry and energy calculation in esolver
ModuleBase::timer::tick("HSolverPW_SDFT", "solve");
Expand Down
7 changes: 7 additions & 0 deletions source/module_hsolver/test/test_hsolver_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ void Stochastic_Iter<T, Device>::sum_stoband(Stochastic_WF<T, Device>& stowf,
return;
}

// Empty definition of Stochastic_Iter::cal_storho provided by this test file:
// the body intentionally does nothing.
template <typename T, typename Device>
void Stochastic_Iter<T, Device>::cal_storho(Stochastic_WF<T, Device>& stowf,
elecstate::ElecStatePW<T, Device>* pes,
ModulePW::PW_Basis_K* wfc_basis)
{
}

// Empty Charge constructor defined locally in this test file.
Charge::Charge()
{
}
// Empty Charge destructor defined locally in this test file.
Charge::~Charge()
{
}

Expand Down
Loading