Skip to content

Commit 0041bfd

Browse files
committed
Fix: dav_subspace for uspp
1 parent 78c4946 commit 0041bfd

File tree

4 files changed

+51
-40
lines changed

4 files changed

+51
-40
lines changed

source/module_lr/hsolver_lrtd.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,8 @@ namespace LR
105105
PARAM.inp.diag_subspace,
106106
PARAM.inp.nb2d);
107107
std::vector<double> ethr_band(nband, diag_ethr);
108-
hsolver::DiagoIterAssist<T>::avg_iter
109-
+= static_cast<double>(dav_subspace.diag(
110-
hpsi_func, psi,
111-
dim,
112-
eigenvalue.data(),
113-
ethr_band,
114-
false /*scf*/));
108+
hsolver::DiagoIterAssist<T>::avg_iter += static_cast<double>(
109+
dav_subspace.diag(hpsi_func, spsi_func, psi, dim, eigenvalue.data(), ethr_band, false /*scf*/));
115110
}
116111
else if (method == "cg")
117112
{

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
9696

9797
template <typename T, typename Device>
9898
int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
99+
const HPsiFunc& spsi_func,
99100
T* psi_in,
100101
const int psi_in_dmax,
101102
Real* eigenvalue_in_hsolver,
@@ -134,7 +135,11 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
134135
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
135136
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->notconv);
136137

137-
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
138+
// compute s*psi_in_iter
139+
// sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
140+
spsi_func(this->psi_in_iter, this->sphi, this->dim, this->notconv);
141+
142+
this->cal_elem(this->dim, nbase, this->notconv, this->sphi, this->hphi, this->hcc, this->scc);
138143

139144
this->diag_zhegvx(nbase, this->notconv, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);
140145

@@ -161,7 +166,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
161166
unconv.data(),
162167
&eigenvalue_iter);
163168

164-
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
169+
this->cal_elem(this->dim, nbase, this->notconv, this->sphi, this->hphi, this->hcc, this->scc);
165170

166171
this->diag_zhegvx(nbase, this->n_band, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);
167172

@@ -405,7 +410,7 @@ template <typename T, typename Device>
405410
void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
406411
int& nbase,
407412
const int& notconv,
408-
const T* psi_iter,
413+
const T* spsi,
409414
const T* hphi,
410415
T* hcc,
411416
T* scc)
@@ -416,39 +421,39 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
416421
ModuleBase::gemm_op_mt<T, Device>()
417422
#else
418423
ModuleBase::gemm_op<T, Device>()
419-
#endif
420-
('C',
421-
'N',
422-
nbase + notconv,
423-
notconv,
424-
this->dim,
425-
this->one,
426-
psi_iter,
427-
this->dim,
428-
&hphi[nbase * this->dim],
429-
this->dim,
430-
this->zero,
431-
&hcc[nbase * this->nbase_x],
432-
this->nbase_x);
424+
#endif
425+
('C',
426+
'N',
427+
nbase + notconv,
428+
notconv,
429+
this->dim,
430+
this->one,
431+
spsi,
432+
this->dim,
433+
&hphi[nbase * this->dim],
434+
this->dim,
435+
this->zero,
436+
&hcc[nbase * this->nbase_x],
437+
this->nbase_x);
433438

434439
#ifdef __DSP
435440
ModuleBase::gemm_op_mt<T, Device>()
436441
#else
437442
ModuleBase::gemm_op<T, Device>()
438443
#endif
439-
('C',
440-
'N',
441-
nbase + notconv,
442-
notconv,
443-
this->dim,
444-
this->one,
445-
psi_iter,
446-
this->dim,
447-
psi_iter + nbase * this->dim,
448-
this->dim,
449-
this->zero,
450-
&scc[nbase * this->nbase_x],
451-
this->nbase_x);
444+
('C',
445+
'N',
446+
nbase + notconv,
447+
notconv,
448+
this->dim,
449+
this->one,
450+
spsi,
451+
this->dim,
452+
spsi + nbase * this->dim,
453+
this->dim,
454+
this->zero,
455+
&scc[nbase * this->nbase_x],
456+
this->nbase_x);
452457

453458
#ifdef __MPI
454459
if (this->diag_comm.nproc > 1)
@@ -776,6 +781,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
776781

777782
template <typename T, typename Device>
778783
int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
784+
const HPsiFunc& spsi_func,
779785
T* psi_in,
780786
const int psi_in_dmax,
781787
Real* eigenvalue_in_hsolver,
@@ -791,7 +797,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
791797
do
792798
{
793799

794-
sum_iter += this->diag_once(hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
800+
sum_iter += this->diag_once(hpsi_func, spsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
795801

796802
++ntry;
797803

source/source_hsolver/diago_dav_subspace.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Diago_DavSubspace
4141
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
4242

4343
int diag(const HPsiFunc& hpsi_func,
44+
const HPsiFunc& spsi_func,
4445
T* psi_in,
4546
const int psi_in_dmax,
4647
Real* eigenvalue_in,
@@ -81,6 +82,9 @@ class Diago_DavSubspace
8182
/// the product of H and psi in the reduced basis set
8283
T* hphi = nullptr;
8384

85+
/// the product of S and psi in the reduced basis set
86+
T* sphi = nullptr;
87+
8488
/// Hamiltonian on the reduced basis
8589
T* hcc = nullptr;
8690

@@ -105,7 +109,7 @@ class Diago_DavSubspace
105109
const int* unconv,
106110
std::vector<Real>* eigenvalue_iter);
107111

108-
void cal_elem(const int& dim, int& nbase, const int& notconv, const T* psi_iter, const T* hphi, T* hcc, T* scc);
112+
void cal_elem(const int& dim, int& nbase, const int& notconv, const T* sphi, const T* hphi, T* hcc, T* scc);
109113

110114
void refresh(const int& dim,
111115
const int& nband,
@@ -134,6 +138,7 @@ class Diago_DavSubspace
134138
T* vcc);
135139

136140
int diag_once(const HPsiFunc& hpsi_func,
141+
const HPsiFunc& spsi_func,
137142
T* psi_in,
138143
const int psi_in_dmax,
139144
Real* eigenvalue_in,

source/source_hsolver/hsolver_pw.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
380380
};
381381
bool scf = this->calculation_type == "nscf" ? false : true;
382382

383+
auto spsi_func = [hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
384+
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
385+
};
386+
383387
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
384388
psi.get_nbands(),
385389
psi.get_k_first() ? psi.get_current_ngk()
@@ -393,7 +397,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
393397
PARAM.inp.nb2d);
394398

395399
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
396-
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
400+
dav_subspace
401+
.diag(hpsi_func, spsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
397402
}
398403
else if (this->method == "dav")
399404
{

0 commit comments

Comments
 (0)