Skip to content

Commit 2f59a6f

Browse files
authored
Fix: dav_subspace for uspp (#6428)
* Feature: plot sts_line_ave * Fix: dav_subspace for uspp * Fix: refresh sphi in dav_subspace * Fix: pyabacus interface * add syncmem_op for PyDiagoDavSubspace
1 parent acaa1fc commit 2f59a6f

File tree

6 files changed

+142
-47
lines changed

6 files changed

+142
-47
lines changed

python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ class PyDiagoDavSubspace
132132
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
133133
};
134134

135+
auto spsi_func = [this](const std::complex<double>* psi_in,
136+
std::complex<double>* spsi_out,
137+
const int ld_psi,
138+
const int nvec) { syncmem_op()(spsi_out, psi_in, static_cast<size_t>(ld_psi * nvec)); };
139+
135140
obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(
136141
precond_vec,
137142
nband,
@@ -145,7 +150,7 @@ class PyDiagoDavSubspace
145150
nb2d
146151
);
147152

148-
return obj->diag(hpsi_func, psi, nbasis, eigenvalue, diag_ethr, scf_type);
153+
return obj->diag(hpsi_func, spsi_func, psi, nbasis, eigenvalue, diag_ethr, scf_type);
149154
}
150155

151156
private:
@@ -156,6 +161,10 @@ class PyDiagoDavSubspace
156161
int nband;
157162

158163
std::unique_ptr<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>> obj;
164+
165+
base_device::DEVICE_CPU* ctx = {};
166+
using syncmem_op = base_device::memory::
167+
synchronize_memory_op<std::complex<double>, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;
159168
};
160169

161170
} // namespace py_hsolver

source/source_hsolver/diago_dav_subspace.cpp

Lines changed: 82 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
5454
resmem_complex_op()(this->hphi, this->nbase_x * this->dim, "DAV::hphi");
5555
setmem_complex_op()(this->hphi, 0, this->nbase_x * this->dim);
5656

57+
// the product of S and psi in the reduced psi set
58+
resmem_complex_op()(this->sphi, this->nbase_x * this->dim, "DAV::sphi");
59+
setmem_complex_op()(this->sphi, 0, this->nbase_x * this->dim);
60+
5761
// Hamiltonian on the reduced psi set
5862
resmem_complex_op()(this->hcc, this->nbase_x * this->nbase_x, "DAV::hcc");
5963
setmem_complex_op()(this->hcc, 0, this->nbase_x * this->nbase_x);
@@ -96,6 +100,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
96100

97101
template <typename T, typename Device>
98102
int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
103+
const HPsiFunc& spsi_func,
99104
T* psi_in,
100105
const int psi_in_dmax,
101106
Real* eigenvalue_in_hsolver,
@@ -134,7 +139,11 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
134139
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
135140
hpsi_func(this->psi_in_iter, this->hphi, this->dim, this->notconv);
136141

137-
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
142+
// compute s*psi_in_iter
143+
// sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
144+
spsi_func(this->psi_in_iter, this->sphi, this->dim, this->notconv);
145+
146+
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->sphi, this->hphi, this->hcc, this->scc);
138147

139148
this->diag_zhegvx(nbase, this->notconv, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);
140149

@@ -152,16 +161,25 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
152161
dav_iter++;
153162

154163
this->cal_grad(hpsi_func,
164+
spsi_func,
155165
this->dim,
156166
nbase,
157167
this->notconv,
158168
this->psi_in_iter,
159169
this->hphi,
170+
this->sphi,
160171
this->vcc,
161172
unconv.data(),
162173
&eigenvalue_iter);
163174

164-
this->cal_elem(this->dim, nbase, this->notconv, this->psi_in_iter, this->hphi, this->hcc, this->scc);
175+
this->cal_elem(this->dim,
176+
nbase,
177+
this->notconv,
178+
this->psi_in_iter,
179+
this->sphi,
180+
this->hphi,
181+
this->hcc,
182+
this->scc);
165183

166184
this->diag_zhegvx(nbase, this->n_band, this->hcc, this->scc, this->nbase_x, &eigenvalue_iter, this->vcc);
167185

@@ -238,6 +256,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
238256
eigenvalue_in_hsolver,
239257
this->psi_in_iter,
240258
this->hphi,
259+
this->sphi,
241260
this->hcc,
242261
this->scc,
243262
this->vcc);
@@ -255,11 +274,13 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
255274

256275
template <typename T, typename Device>
257276
void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
277+
const HPsiFunc& spsi_func,
258278
const int& dim,
259279
const int& nbase,
260280
const int& notconv,
261281
T* psi_iter,
262282
T* hphi,
283+
T* spsi,
263284
T* vcc,
264285
const int* unconv,
265286
std::vector<Real>* eigenvalue_iter)
@@ -331,7 +352,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
331352
notconv,
332353
nbase,
333354
this->one,
334-
psi_iter,
355+
sphi,
335356
this->dim,
336357
vcc,
337358
this->nbase_x,
@@ -396,6 +417,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
396417
// update hpsi[:, nbase:nbase+notconv]
397418
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
398419
hpsi_func(psi_iter + nbase * dim, hphi + nbase * this->dim, this->dim, notconv);
420+
spsi_func(psi_iter + nbase * dim, sphi + nbase * this->dim, this->dim, notconv);
399421

400422
ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
401423
return;
@@ -406,6 +428,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
406428
int& nbase,
407429
const int& notconv,
408430
const T* psi_iter,
431+
const T* spsi,
409432
const T* hphi,
410433
T* hcc,
411434
T* scc)
@@ -416,39 +439,39 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
416439
ModuleBase::gemm_op_mt<T, Device>()
417440
#else
418441
ModuleBase::gemm_op<T, Device>()
419-
#endif
420-
('C',
421-
'N',
422-
nbase + notconv,
423-
notconv,
424-
this->dim,
425-
this->one,
426-
psi_iter,
427-
this->dim,
428-
&hphi[nbase * this->dim],
429-
this->dim,
430-
this->zero,
431-
&hcc[nbase * this->nbase_x],
432-
this->nbase_x);
442+
#endif
443+
('C',
444+
'N',
445+
nbase + notconv,
446+
notconv,
447+
this->dim,
448+
this->one,
449+
psi_iter,
450+
this->dim,
451+
&hphi[nbase * this->dim],
452+
this->dim,
453+
this->zero,
454+
&hcc[nbase * this->nbase_x],
455+
this->nbase_x);
433456

434457
#ifdef __DSP
435458
ModuleBase::gemm_op_mt<T, Device>()
436459
#else
437460
ModuleBase::gemm_op<T, Device>()
438461
#endif
439-
('C',
440-
'N',
441-
nbase + notconv,
442-
notconv,
443-
this->dim,
444-
this->one,
445-
psi_iter,
446-
this->dim,
447-
psi_iter + nbase * this->dim,
448-
this->dim,
449-
this->zero,
450-
&scc[nbase * this->nbase_x],
451-
this->nbase_x);
462+
('C',
463+
'N',
464+
nbase + notconv,
465+
notconv,
466+
this->dim,
467+
this->one,
468+
psi_iter,
469+
this->dim,
470+
spsi + nbase * this->dim,
471+
this->dim,
472+
this->zero,
473+
&scc[nbase * this->nbase_x],
474+
this->nbase_x);
452475

453476
#ifdef __MPI
454477
if (this->diag_comm.nproc > 1)
@@ -685,10 +708,11 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
685708
const Real* eigenvalue_in_hsolver,
686709
// const psi::Psi<T, Device>& psi,
687710
T* psi_iter,
688-
T* hp,
689-
T* sp,
690-
T* hc,
691-
T* vc)
711+
T* hphi,
712+
T* sphi,
713+
T* hcc,
714+
T* scc,
715+
T* vcc)
692716
{
693717
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");
694718

@@ -714,6 +738,28 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
714738
// update hphi
715739
syncmem_complex_op()(hphi, psi_iter + nband * this->dim, this->dim * nband);
716740

741+
#ifdef __DSP
742+
ModuleBase::gemm_op_mt<T, Device>()
743+
#else
744+
ModuleBase::gemm_op<T, Device>()
745+
#endif
746+
('N',
747+
'N',
748+
this->dim,
749+
nband,
750+
nbase,
751+
this->one,
752+
this->sphi,
753+
this->dim,
754+
this->vcc,
755+
this->nbase_x,
756+
this->zero,
757+
psi_iter + nband * this->dim,
758+
this->dim);
759+
760+
// update sphi
761+
syncmem_complex_op()(sphi, psi_iter + nband * this->dim, this->dim * nband);
762+
717763
nbase = nband;
718764

719765
// set hcc/scc/vcc to 0
@@ -776,6 +822,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
776822

777823
template <typename T, typename Device>
778824
int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
825+
const HPsiFunc& spsi_func,
779826
T* psi_in,
780827
const int psi_in_dmax,
781828
Real* eigenvalue_in_hsolver,
@@ -791,7 +838,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
791838
do
792839
{
793840

794-
sum_iter += this->diag_once(hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
841+
sum_iter += this->diag_once(hpsi_func, spsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
795842

796843
++ntry;
797844

source/source_hsolver/diago_dav_subspace.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Diago_DavSubspace
4141
using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
4242

4343
int diag(const HPsiFunc& hpsi_func,
44+
const HPsiFunc& spsi_func,
4445
T* psi_in,
4546
const int psi_in_dmax,
4647
Real* eigenvalue_in,
@@ -81,6 +82,9 @@ class Diago_DavSubspace
8182
/// the product of H and psi in the reduced basis set
8283
T* hphi = nullptr;
8384

85+
/// the product of S and psi in the reduced basis set
86+
T* sphi = nullptr;
87+
8488
/// Hamiltonian on the reduced basis
8589
T* hcc = nullptr;
8690

@@ -96,23 +100,33 @@ class Diago_DavSubspace
96100
base_device::AbacusDevice_t device = {};
97101

98102
void cal_grad(const HPsiFunc& hpsi_func,
103+
const HPsiFunc& spsi_func,
99104
const int& dim,
100105
const int& nbase,
101106
const int& notconv,
102107
T* psi_iter,
103108
T* hphi,
109+
T* spsi,
104110
T* vcc,
105111
const int* unconv,
106112
std::vector<Real>* eigenvalue_iter);
107113

108-
void cal_elem(const int& dim, int& nbase, const int& notconv, const T* psi_iter, const T* hphi, T* hcc, T* scc);
114+
void cal_elem(const int& dim,
115+
int& nbase,
116+
const int& notconv,
117+
const T* psi_iter,
118+
const T* sphi,
119+
const T* hphi,
120+
T* hcc,
121+
T* scc);
109122

110123
void refresh(const int& dim,
111124
const int& nband,
112125
int& nbase,
113126
const Real* eigenvalue,
114127
T* psi_iter,
115128
T* hphi,
129+
T* sphi,
116130
T* hcc,
117131
T* scc,
118132
T* vcc);
@@ -134,6 +148,7 @@ class Diago_DavSubspace
134148
T* vcc);
135149

136150
int diag_once(const HPsiFunc& hpsi_func,
151+
const HPsiFunc& spsi_func,
137152
T* psi_in,
138153
const int psi_in_dmax,
139154
Real* eigenvalue_in,

source/source_hsolver/hsolver_pw.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
380380
};
381381
bool scf = this->calculation_type == "nscf" ? false : true;
382382

383+
auto spsi_func = [hm](T* psi_in, T* spsi_out, const int ld_psi, const int nvec) {
384+
hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec);
385+
};
386+
383387
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
384388
psi.get_nbands(),
385389
psi.get_k_first() ? psi.get_current_ngk()
@@ -393,7 +397,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
393397
PARAM.inp.nb2d);
394398

395399
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
396-
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
400+
dav_subspace
401+
.diag(hpsi_func, spsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, this->ethr_band, scf));
397402
}
398403
else if (this->method == "dav")
399404
{

source/source_lcao/module_lr/hsolver_lrtd.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,8 @@ namespace LR
105105
PARAM.inp.diag_subspace,
106106
PARAM.inp.nb2d);
107107
std::vector<double> ethr_band(nband, diag_ethr);
108-
hsolver::DiagoIterAssist<T>::avg_iter
109-
+= static_cast<double>(dav_subspace.diag(
110-
hpsi_func, psi,
111-
dim,
112-
eigenvalue.data(),
113-
ethr_band,
114-
false /*scf*/));
108+
hsolver::DiagoIterAssist<T>::avg_iter += static_cast<double>(
109+
dav_subspace.diag(hpsi_func, spsi_func, psi, dim, eigenvalue.data(), ethr_band, false /*scf*/));
115110
}
116111
else if (method == "cg")
117112
{

0 commit comments

Comments
 (0)