Skip to content

Commit 2e3baa1

Browse files
committed
Fix: refresh sphi in dav_subspace
1 parent 0041bfd commit 2e3baa1

File tree

2 files changed

+61
-10
lines changed

2 files changed

+61
-10
lines changed

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 50 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,10 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
5454
resmem_complex_op()(this->hphi, this->nbase_x * this->dim, "DAV::hphi");
5555
setmem_complex_op()(this->hphi, 0, this->nbase_x * this->dim);
5656

57+
// the product of S and psi in the reduced psi set
58+
resmem_complex_op()(this->sphi, this->nbase_x * this->dim, "DAV::sphi");
59+
setmem_complex_op()(this->sphi, 0, this->nbase_x * this->dim);
60+
5761
// Hamiltonian on the reduced psi set
5862
resmem_complex_op()(this->hcc, this->nbase_x * this->nbase_x, "DAV::hcc");
5963
setmem_complex_op()(this->hcc, 0, this->nbase_x * this->nbase_x);
@@ -139,7 +143,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
139143
// sphi[:, 0:notconv] = S * psi_in_iter[:, 0:notconv]  (the call below passes notconv columns, not nbase_x)
140144
spsi_func(this->psi_in_iter, this->sphi, this->dim, this->notconv);
141145

142-
this->cal_elem(this->dim, nbase, this->notconv, this->sphi, this->hphi, this->hcc, this->scc);
146+
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->sphi, this->hphi, this->hcc, this->scc);
143147

144148
this->diag_zhegvx(nbase, this->notconv, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);
145149

@@ -157,16 +161,25 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
157161
dav_iter++;
158162

159163
this->cal_grad(hpsi_func,
164+
spsi_func,
160165
this->dim,
161166
nbase,
162167
this->notconv,
163168
this->psi_in_iter,
164169
this->hphi,
170+
this->sphi,
165171
this->vcc,
166172
unconv.data(),
167173
&eigenvalue_iter);
168174

169-
this->cal_elem(this->dim, nbase, this->notconv, this->sphi, this->hphi, this->hcc, this->scc);
175+
this->cal_elem(this->dim,
176+
nbase,
177+
this->notconv,
178+
this->psi_in_iter,
179+
this->sphi,
180+
this->hphi,
181+
this->hcc,
182+
this->scc);
170183

171184
this->diag_zhegvx(nbase, this->n_band, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);
172185

@@ -243,6 +256,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
243256
eigenvalue_in_hsolver,
244257
this->psi_in_iter,
245258
this->hphi,
259+
this->sphi,
246260
this->hcc,
247261
this->scc,
248262
this->vcc);
@@ -260,11 +274,13 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
260274

261275
template <typename T, typename Device>
262276
void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
277+
const HPsiFunc& spsi_func,
263278
const int& dim,
264279
const int& nbase,
265280
const int& notconv,
266281
T* psi_iter,
267282
T* hphi,
283+
T* spsi,
268284
T* vcc,
269285
const int* unconv,
270286
std::vector<Real>* eigenvalue_iter)
@@ -336,7 +352,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
336352
notconv,
337353
nbase,
338354
this->one,
339-
psi_iter,
355+
sphi,
340356
this->dim,
341357
vcc,
342358
this->nbase_x,
@@ -401,6 +417,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
401417
// update hphi[:, nbase:nbase+notconv] and sphi[:, nbase:nbase+notconv]
402418
// hphi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]; sphi[:, nbase:nbase+notconv] = S * psi_iter[:, nbase:nbase+notconv]
403419
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
420+
spsi_func(psi_iter + nbase * dim, sphi + nbase * this->dim, this->dim, notconv);
404421

405422
ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
406423
return;
@@ -410,6 +427,7 @@ template <typename T, typename Device>
410427
void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
411428
int& nbase,
412429
const int& notconv,
430+
const T* psi_iter,
413431
const T* spsi,
414432
const T* hphi,
415433
T* hcc,
@@ -428,7 +446,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
428446
notconv,
429447
this->dim,
430448
this->one,
431-
spsi,
449+
psi_iter,
432450
this->dim,
433451
&hphi[nbase * this->dim],
434452
this->dim,
@@ -447,7 +465,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
447465
notconv,
448466
this->dim,
449467
this->one,
450-
spsi,
468+
psi_iter,
451469
this->dim,
452470
spsi + nbase * this->dim,
453471
this->dim,
@@ -690,10 +708,11 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
690708
const Real* eigenvalue_in_hsolver,
691709
// const psi::Psi<T, Device>& psi,
692710
T* psi_iter,
693-
T* hp,
694-
T* sp,
695-
T* hc,
696-
T* vc)
711+
T* hphi,
712+
T* sphi,
713+
T* hcc,
714+
T* scc,
715+
T* vcc)
697716
{
698717
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");
699718

@@ -719,6 +738,28 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
719738
// update hphi
720739
syncmem_complex_op()(hphi, psi_iter + nband * this->dim, this->dim * nband);
721740

741+
#ifdef __DSP
742+
ModuleBase::gemm_op_mt<T, Device>()
743+
#else
744+
ModuleBase::gemm_op<T, Device>()
745+
#endif
746+
('N',
747+
'N',
748+
this->dim,
749+
nband,
750+
nbase,
751+
this->one,
752+
this->sphi,
753+
this->dim,
754+
this->vcc,
755+
this->nbase_x,
756+
this->zero,
757+
psi_iter + nband * this->dim,
758+
this->dim);
759+
760+
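// update sphi: copy the sphi * vcc product (staged by the gemm above in psi_iter + nband * dim) back into sphi[:, 0:nband]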
761+
syncmem_complex_op()(sphi, psi_iter + nband * this->dim, this->dim * nband);
762+
722763
nbase = nband;
723764

724765
// set hcc/scc/vcc to 0

source/source_hsolver/diago_dav_subspace.h

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,23 +100,33 @@ class Diago_DavSubspace
100100
base_device::AbacusDevice_t device = {};
101101

102102
void cal_grad(const HPsiFunc& hpsi_func,
103+
const HPsiFunc& spsi_func,
103104
const int& dim,
104105
const int& nbase,
105106
const int& notconv,
106107
T* psi_iter,
107108
T* hphi,
109+
T* spsi,
108110
T* vcc,
109111
const int* unconv,
110112
std::vector<Real>* eigenvalue_iter);
111113

112-
void cal_elem(const int& dim, int& nbase, const int& notconv, const T* sphi, const T* hphi, T* hcc, T* scc);
114+
void cal_elem(const int& dim,
115+
int& nbase,
116+
const int& notconv,
117+
const T* psi_iter,
118+
const T* sphi,
119+
const T* hphi,
120+
T* hcc,
121+
T* scc);
113122

114123
void refresh(const int& dim,
115124
const int& nband,
116125
int& nbase,
117126
const Real* eigenvalue,
118127
T* psi_iter,
119128
T* hphi,
129+
T* sphi,
120130
T* hcc,
121131
T* scc,
122132
T* vcc);

0 commit comments

Comments
 (0)