Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class PyDiagoDavSubspace
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
};

auto spsi_func = [this](const std::complex<double>* psi_in,
std::complex<double>* spsi_out,
const int ld_psi,
const int nvec) { syncmem_op()(spsi_out, psi_in, static_cast<size_t>(ld_psi * nvec)); };

obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(
precond_vec,
nband,
Expand All @@ -145,7 +150,7 @@ class PyDiagoDavSubspace
nb2d
);

return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr, scf_type);
return obj->diag(hpsi_func, spsi_func, psi, nbasis, eigenvalue, diag_ethr, scf_type);
}

private:
Expand All @@ -156,6 +161,10 @@ class PyDiagoDavSubspace
int nband;

std::unique_ptr<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>> obj;

base_device::DEVICE_CPU* ctx = {};
using syncmem_op = base_device::memory::
synchronize_memory_op<std::complex<double>, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;
};

} // namespace py_hsolver
Expand Down
117 changes: 82 additions & 35 deletions source/source_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
resmem_complex_op()(this->hphi, this->nbase_x * this->dim, "DAV::hphi");
setmem_complex_op()(this->hphi, 0, this->nbase_x * this->dim);

// the product of S and psi in the reduced psi set
resmem_complex_op()(this->sphi, this->nbase_x * this->dim, "DAV::sphi");
setmem_complex_op()(this->sphi, 0, this->nbase_x * this->dim);

// Hamiltonian on the reduced psi set
resmem_complex_op()(this->hcc, this->nbase_x * this->nbase_x, "DAV::hcc");
setmem_complex_op()(this->hcc, 0, this->nbase_x * this->nbase_x);
Expand Down Expand Up @@ -96,6 +100,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()

template <typename T, typename Device>
int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
const HPsiFunc& spsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
Expand Down Expand Up @@ -134,7 +139,11 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->notconv);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
// compute s*psi_in_iter
// sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
spsi_func(this->psi_in_iter, this->sphi, this->dim, this->notconv);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->sphi, this->hphi, this->hcc, this->scc);

this->diag_zhegvx(nbase, this->notconv, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);

Expand All @@ -152,16 +161,25 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
dav_iter++;

this->cal_grad(hpsi_func,
spsi_func,
this->dim,
nbase,
this->notconv,
this->psi_in_iter,
this->hphi,
this->sphi,
this->vcc,
unconv.data(),
&eigenvalue_iter);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
this->cal_elem(this->dim,
nbase,
this->notconv,
this->psi_in_iter,
this->sphi,
this->hphi,
this->hcc,
this->scc);

this->diag_zhegvx(nbase, this->n_band, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);

Expand Down Expand Up @@ -238,6 +256,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
eigenvalue_in_hsolver,
this->psi_in_iter,
this->hphi,
this->sphi,
this->hcc,
this->scc,
this->vcc);
Expand All @@ -255,11 +274,13 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,

template <typename T, typename Device>
void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
const HPsiFunc& spsi_func,
const int& dim,
const int& nbase,
const int& notconv,
T* psi_iter,
T* hphi,
T* spsi,
T* vcc,
const int* unconv,
std::vector<Real>* eigenvalue_iter)
Expand Down Expand Up @@ -331,7 +352,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
notconv,
nbase,
this->one,
psi_iter,
sphi,
this->dim,
vcc,
this->nbase_x,
Expand Down Expand Up @@ -396,6 +417,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
// update hpsi[:, nbase:nbase+notconv]
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
spsi_func(psi_iter + nbase * dim, sphi + nbase * this->dim, this->dim, notconv);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
return;
Expand All @@ -406,6 +428,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
int& nbase,
const int& notconv,
const T* psi_iter,
const T* spsi,
const T* hphi,
T* hcc,
T* scc)
Expand All @@ -416,39 +439,39 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
ModuleBase::gemm_op_mt<T, Device>()
#else
ModuleBase::gemm_op<T, Device>()
#endif
('C',
'N',
nbase + notconv,
notconv,
this->dim,
this->one,
psi_iter,
this->dim,
&hphi[nbase * this->dim],
this->dim,
this->zero,
&hcc[nbase * this->nbase_x],
this->nbase_x);
#endif
('C',
'N',
nbase + notconv,
notconv,
this->dim,
this->one,
psi_iter,
this->dim,
&hphi[nbase * this->dim],
this->dim,
this->zero,
&hcc[nbase * this->nbase_x],
this->nbase_x);

#ifdef __DSP
ModuleBase::gemm_op_mt<T, Device>()
#else
ModuleBase::gemm_op<T, Device>()
#endif
('C',
'N',
nbase + notconv,
notconv,
this->dim,
this->one,
psi_iter,
this->dim,
psi_iter + nbase * this->dim,
this->dim,
this->zero,
&scc[nbase * this->nbase_x],
this->nbase_x);
('C',
'N',
nbase + notconv,
notconv,
this->dim,
this->one,
psi_iter,
this->dim,
spsi + nbase * this->dim,
this->dim,
this->zero,
&scc[nbase * this->nbase_x],
this->nbase_x);

#ifdef __MPI
if (this->diag_comm.nproc > 1)
Expand Down Expand Up @@ -685,10 +708,11 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
const Real* eigenvalue_in_hsolver,
// const psi::Psi<T, Device>& psi,
T* psi_iter,
T* hp,
T* sp,
T* hc,
T* vc)
T* hphi,
T* sphi,
T* hcc,
T* scc,
T* vcc)
{
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");

Expand All @@ -714,6 +738,28 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
// update hphi
syncmem_complex_op()(hphi, psi_iter + nband * this->dim, this->dim * nband);

#ifdef __DSP
ModuleBase::gemm_op_mt<T, Device>()
#else
ModuleBase::gemm_op<T, Device>()
#endif
('N',
'N',
this->dim,
nband,
nbase,
this->one,
this->sphi,
this->dim,
this->vcc,
this->nbase_x,
this->zero,
psi_iter + nband * this->dim,
this->dim);

// update sphi
syncmem_complex_op()(sphi, psi_iter + nband * this->dim, this->dim * nband);

nbase = nband;

// set hcc/scc/vcc to 0
Expand Down Expand Up @@ -776,6 +822,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,

template <typename T, typename Device>
int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
const HPsiFunc& spsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in_hsolver,
Expand All @@ -791,7 +838,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
do
{

sum_iter += this->diag_once(hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
sum_iter += this->diag_once(hpsi_func, spsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);

++ntry;

Expand Down
17 changes: 16 additions & 1 deletion source/source_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Diago_DavSubspace
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;

int diag(const HPsiFunc& hpsi_func,
const HPsiFunc& spsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in,
Expand Down Expand Up @@ -81,6 +82,9 @@ class Diago_DavSubspace
/// the product of H and psi in the reduced basis set
T* hphi = nullptr;

/// the product of S and psi in the reduced basis set
T* sphi = nullptr;

/// Hamiltonian on the reduced basis
T* hcc = nullptr;

Expand All @@ -96,23 +100,33 @@ class Diago_DavSubspace
base_device::AbacusDevice_t device = {};

void cal_grad(const HPsiFunc& hpsi_func,
const HPsiFunc& spsi_func,
const int& dim,
const int& nbase,
const int& notconv,
T* psi_iter,
T* hphi,
T* spsi,
T* vcc,
const int* unconv,
std::vector<Real>* eigenvalue_iter);

void cal_elem(const int& dim, int& nbase, const int& notconv, const T* psi_iter, const T* hphi, T* hcc, T* scc);
void cal_elem(const int& dim,
int& nbase,
const int& notconv,
const T* psi_iter,
const T* sphi,
const T* hphi,
T* hcc,
T* scc);

void refresh(const int& dim,
const int& nband,
int& nbase,
const Real* eigenvalue,
T* psi_iter,
T* hphi,
T* sphi,
T* hcc,
T* scc,
T* vcc);
Expand All @@ -134,6 +148,7 @@ class Diago_DavSubspace
T* vcc);

int diag_once(const HPsiFunc& hpsi_func,
const HPsiFunc& spsi_func,
T* psi_in,
const int psi_in_dmax,
Real* eigenvalue_in,
Expand Down
7 changes: 6 additions & 1 deletion source/source_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
};
bool scf = this->calculation_type == "nscf" ? false : true;

auto spsi_func = [hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
};

Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_ngk()
Expand All @@ -393,7 +397,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
PARAM.inp.nb2d);

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
dav_subspace
.diag(hpsi_func, spsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
}
else if (this->method == "dav")
{
Expand Down
9 changes: 2 additions & 7 deletions source/source_lcao/module_lr/hsolver_lrtd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,8 @@ namespace LR
PARAM.inp.diag_subspace,
PARAM.inp.nb2d);
std::vector<double> ethr_band(nband, diag_ethr);
hsolver::DiagoIterAssist<T>::avg_iter
+= static_cast<double>(dav_subspace.diag(
hpsi_func, psi,
dim,
eigenvalue.data(),
ethr_band,
false /*scf*/));
hsolver::DiagoIterAssist<T>::avg_iter += static_cast<double>(
dav_subspace.diag(hpsi_func, spsi_func, psi, dim, eigenvalue.data(), ethr_band, false /*scf*/));
}
else if (method == "cg")
{
Expand Down
Loading
Loading