Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions source/source_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
setmem_complex_op()(this->psi_in_iter, 0, this->nbase_x * this->dim);

// the product of H and psi in the reduced psi set
resmem_complex_op()(this->hphi, this->nbase_x * this->dim, "DAV::hphi");
setmem_complex_op()(this->hphi, 0, this->nbase_x * this->dim);
resmem_complex_op()(this->hpsi, this->nbase_x * this->dim, "DAV::hpsi");
setmem_complex_op()(this->hpsi, 0, this->nbase_x * this->dim);

// the product of S and psi in the reduced psi set
resmem_complex_op()(this->sphi, this->nbase_x * this->dim, "DAV::sphi");
setmem_complex_op()(this->sphi, 0, this->nbase_x * this->dim);
resmem_complex_op()(this->spsi, this->nbase_x * this->dim, "DAV::spsi");
setmem_complex_op()(this->spsi, 0, this->nbase_x * this->dim);

// Hamiltonian on the reduced psi set
resmem_complex_op()(this->hcc, this->nbase_x * this->nbase_x, "DAV::hcc");
Expand Down Expand Up @@ -87,7 +87,8 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
{
delmem_complex_op()(this->psi_in_iter);

delmem_complex_op()(this->hphi);
delmem_complex_op()(this->hpsi);
delmem_complex_op()(this->spsi);
delmem_complex_op()(this->hcc);
delmem_complex_op()(this->scc);
delmem_complex_op()(this->vcc);
Expand Down Expand Up @@ -137,14 +138,14 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,

// compute h*psi_in_iter
// NOTE: bands after the first n_band should yield zero
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->notconv);
// hpsi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
hpsi_func(this->psi_in_iter, this->hpsi, this->dim, this->notconv);

// compute s*psi_in_iter
// sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
spsi_func(this->psi_in_iter, this->sphi, this->dim, this->notconv);
// spsi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
spsi_func(this->psi_in_iter, this->spsi, this->dim, this->notconv);

this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->sphi, this->hphi, this->hcc, this->scc);
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->spsi, this->hpsi, this->hcc, this->scc);

this->diag_zhegvx(nbase, this->notconv, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);

Expand All @@ -167,8 +168,8 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
nbase,
this->notconv,
this->psi_in_iter,
this->hphi,
this->sphi,
this->hpsi,
this->spsi,
this->vcc,
unconv.data(),
&eigenvalue_iter);
Expand All @@ -177,8 +178,8 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
nbase,
this->notconv,
this->psi_in_iter,
this->sphi,
this->hphi,
this->spsi,
this->hpsi,
this->hcc,
this->scc);

Expand Down Expand Up @@ -251,8 +252,8 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
nbase,
eigenvalue_in_hsolver,
this->psi_in_iter,
this->hphi,
this->sphi,
this->hpsi,
this->spsi,
this->hcc,
this->scc,
this->vcc);
Expand All @@ -275,7 +276,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
const int& nbase,
const int& notconv,
T* psi_iter,
T* hphi,
T* hpsi,
T* spsi,
T* vcc,
const int* unconv,
Expand Down Expand Up @@ -303,7 +304,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
notconv,
nbase,
this->one,
hphi,
hpsi,
this->dim,
vcc,
this->nbase_x,
Expand Down Expand Up @@ -333,7 +334,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
notconv,
nbase,
this->one,
sphi,
spsi,
this->dim,
vcc,
this->nbase_x,
Expand Down Expand Up @@ -392,8 +393,8 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,

// update hpsi[:, nbase:nbase+notconv]
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
spsi_func(psi_iter + nbase * dim, sphi + nbase * this->dim, this->dim, notconv);
hpsi_func(psi_iter + nbase * dim, hpsi + nbase * this->dim, this->dim, notconv);
spsi_func(psi_iter + nbase * dim, spsi + nbase * this->dim, this->dim, notconv);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
return;
Expand All @@ -405,7 +406,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
const int& notconv,
const T* psi_iter,
const T* spsi,
const T* hphi,
const T* hpsi,
T* hcc,
T* scc)
{
Expand All @@ -424,7 +425,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
this->one,
psi_iter,
this->dim,
&hphi[nbase * this->dim],
&hpsi[nbase * this->dim],
this->dim,
this->zero,
&hcc[nbase * this->nbase_x],
Expand Down Expand Up @@ -659,8 +660,8 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
const Real* eigenvalue_in_hsolver,
// const psi::Psi<T, Device>& psi,
T* psi_iter,
T* hphi,
T* sphi,
T* hpsi,
T* spsi,
T* hcc,
T* scc,
T* vcc)
Expand All @@ -678,16 +679,16 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
nband,
nbase,
this->one,
this->hphi,
this->hpsi,
this->dim,
this->vcc,
this->nbase_x,
this->zero,
psi_iter + nband * this->dim,
this->dim);

// update hphi
syncmem_complex_op()(hphi, psi_iter + nband * this->dim, this->dim * nband);
// update hpsi
syncmem_complex_op()(hpsi, psi_iter + nband * this->dim, this->dim * nband);

#ifdef __DSP
ModuleBase::gemm_op_mt<T, Device>()
Expand All @@ -700,16 +701,16 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
nband,
nbase,
this->one,
this->sphi,
this->spsi,
this->dim,
this->vcc,
this->nbase_x,
this->zero,
psi_iter + nband * this->dim,
this->dim);

// update sphi
syncmem_complex_op()(sphi, psi_iter + nband * this->dim, this->dim * nband);
// update spsi
syncmem_complex_op()(spsi, psi_iter + nband * this->dim, this->dim * nband);

nbase = nband;

Expand Down
14 changes: 7 additions & 7 deletions source/source_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class Diago_DavSubspace
T* psi_in_iter = nullptr;

/// the product of H and psi in the reduced basis set
T* hphi = nullptr;
T* hpsi = nullptr;

/// the product of S and psi in the reduced basis set
T* sphi = nullptr;
T* spsi = nullptr;

/// Hamiltonian on the reduced basis
T* hcc = nullptr;
Expand All @@ -108,7 +108,7 @@ class Diago_DavSubspace
const int& nbase,
const int& notconv,
T* psi_iter,
T* hphi,
T* hpsi,
T* spsi,
T* vcc,
const int* unconv,
Expand All @@ -118,8 +118,8 @@ class Diago_DavSubspace
int& nbase,
const int& notconv,
const T* psi_iter,
const T* sphi,
const T* hphi,
const T* spsi,
const T* hpsi,
T* hcc,
T* scc);

Expand All @@ -128,8 +128,8 @@ class Diago_DavSubspace
int& nbase,
const Real* eigenvalue,
T* psi_iter,
T* hphi,
T* sphi,
T* hpsi,
T* spsi,
T* hcc,
T* scc,
T* vcc);
Expand Down
Loading