Skip to content

Commit e9578fd

Browse files
committed
Refactor hsolver function to do subspace diagonalization, standard & generalized, used in cg
1 parent 9b6651c commit e9578fd

File tree

7 files changed

+167
-82
lines changed

7 files changed

+167
-82
lines changed

source/source_hsolver/diago_cg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ template <typename T, typename Device>
2828
DiagoCG<T, Device>::DiagoCG(const std::string& basis_type,
2929
const std::string& calculation,
3030
const bool& need_subspace,
31-
const Func& subspace_func,
31+
const SubspaceFunc& subspace_func,
3232
const Real& pw_diag_thr,
3333
const int& pw_diag_nmax,
3434
const int& nproc_in_pool)
@@ -569,7 +569,7 @@ bool DiagoCG<T, Device>::test_exit_cond(const int& ntry, const int& notconv) con
569569
// In non-self consistent calculation, do until totally converged.
570570
const bool f2 = !scf && notconv > 0;
571571
// if self consistent calculation, if not converged > 5,
572-
// using diagH_subspace and cg method again. ntry++
572+
// using diag_subspace and cg method again. ntry++
573573
const bool f3 = scf && notconv > 5;
574574
return f1 && (f2 || f3);
575575
}

source/source_hsolver/diago_iter_assist.cpp

Lines changed: 121 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -18,39 +18,52 @@ namespace hsolver
1818
// Produces on output n_band eigenvectors (n_band <= nstart) in evc.
1919
//----------------------------------------------------------------------
2020
template <typename T, typename Device>
21-
void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>* const pHamilt, // hamiltonian operator carrier
21+
void DiagoIterAssist<T, Device>::diag_subspace(const hamilt::Hamilt<T, Device>* const pHamilt, // hamiltonian operator carrier
2222
const psi::Psi<T, Device>& psi, // [in] wavefunction
23-
psi::Psi<T, Device>& evc, // [out] wavefunction
23+
psi::Psi<T, Device>& evc, // [out] wavefunction, eigenvectors
2424
Real* en, // [out] eigenvalues
25-
int n_band // [in] number of bands to be calculated, also number of rows
25+
int n_band, // [in] number of bands to be calculated, also number of rows
2626
// of evc, if set to 0, n_band = nstart, default 0
27+
const bool S_orth // [in] if true, psi is assumed to be already S-orthogonalized
2728
)
2829
{
29-
ModuleBase::TITLE("DiagoAssist", "diag_subspace");
30-
ModuleBase::timer::tick("DiagoAssist", "diag_subspace");
30+
ModuleBase::TITLE("DiagoIterAssist", "diag_subspace");
31+
ModuleBase::timer::tick("DiagoIterAssist", "diag_subspace");
3132

3233
// two case:
3334
// 1. pw base: nstart = n_band, psi(nbands * npwx)
3435
// 2. lcao_in_pw base: nstart >= n_band, psi(NLOCAL * npwx)
3536
const int nstart = psi.get_nbands();
37+
// n_band = 0 means default, set n_band = nstart
3638
if (n_band == 0)
3739
{
3840
n_band = nstart;
3941
}
4042
assert(n_band <= nstart);
4143

44+
// scc is overlap (optional, only needed if input is not s-orthogonal)
4245
T *hcc = nullptr, *scc = nullptr, *vcc = nullptr;
46+
47+
// hcc is reduced hamiltonian matrix
4348
resmem_complex_op()(hcc, nstart * nstart, "DiagSub::hcc");
44-
resmem_complex_op()(scc, nstart * nstart, "DiagSub::scc");
45-
resmem_complex_op()(vcc, nstart * nstart, "DiagSub::vcc");
4649
setmem_complex_op()(hcc, 0, nstart * nstart);
47-
setmem_complex_op()(scc, 0, nstart * nstart);
50+
51+
// scc is overlap matrix, only needed when psi is not orthogonal
52+
if(!S_orth){
53+
resmem_complex_op()(scc, nstart * nstart, "DiagSub::scc");
54+
setmem_complex_op()(scc, 0, nstart * nstart);
55+
}
56+
57+
// vcc is eigenvector matrix of the reduced generalized eigenvalue problem
58+
resmem_complex_op()(vcc, nstart * nstart, "DiagSub::vcc");
4859
setmem_complex_op()(vcc, 0, nstart * nstart);
4960

61+
// dmin is the active number of plane waves or atomic orbitals
62+
// dmax is the leading dimension of psi
5063
const int dmin = psi.get_current_ngk();
5164
const int dmax = psi.get_nbasis();
5265

53-
T* temp = nullptr;
66+
T *temp = nullptr; /// temporary array for calculation of evc
5467
bool in_place = false; ///< if temp and evc share the same memory
5568
if (psi.get_pointer() != evc.get_pointer() && psi.get_nbands() == evc.get_nbands())
5669
{ // use memory of evc as temp
@@ -65,10 +78,10 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
6578
{ // code block to calculate hcc and scc
6679
setmem_complex_op()(temp, 0, nstart * dmax);
6780

68-
T* hphi = temp;
81+
T *hpsi = temp;
6982
// do hPsi for all bands
7083
psi::Range all_bands_range(1, psi.get_current_k(), 0, nstart - 1);
71-
hpsi_info hpsi_in(&psi, all_bands_range, hphi);
84+
hpsi_info hpsi_in(&psi, all_bands_range, hpsi);
7285
pHamilt->ops->hPsi(hpsi_in);
7386

7487
ModuleBase::gemm_op<T, Device>()('C',
@@ -79,40 +92,50 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
7992
&one,
8093
psi.get_pointer(),
8194
dmax,
82-
hphi,
95+
hpsi,
8396
dmax,
8497
&zero,
8598
hcc,
8699
nstart);
87100

88-
T* sphi = temp;
89-
// do sPsi for all bands
90-
pHamilt->sPsi(psi.get_pointer(), sphi, dmax, dmin, nstart);
91-
92-
ModuleBase::gemm_op<T, Device>()('C',
93-
'N',
94-
nstart,
95-
nstart,
96-
dmin,
97-
&one,
98-
psi.get_pointer(),
99-
dmax,
100-
sphi,
101-
dmax,
102-
&zero,
103-
scc,
104-
nstart);
101+
if(!S_orth){
102+
// Only calculate S_sub if not orthogonal
103+
T *spsi = temp;
104+
// do sPsi for all bands
105+
pHamilt->sPsi(psi.get_pointer(), spsi, dmax, dmin, nstart);
106+
107+
ModuleBase::gemm_op<T, Device>()('C',
108+
'N',
109+
nstart,
110+
nstart,
111+
dmin,
112+
&one,
113+
psi.get_pointer(),
114+
dmax,
115+
spsi,
116+
dmax,
117+
&zero,
118+
scc,
119+
nstart);
120+
}
105121
}
106122

107123
if (GlobalV::NPROC_IN_POOL > 1)
108124
{
109125
Parallel_Reduce::reduce_pool(hcc, nstart * nstart);
110-
Parallel_Reduce::reduce_pool(scc, nstart * nstart);
126+
if(!S_orth){
127+
Parallel_Reduce::reduce_pool(scc, nstart * nstart);
128+
}
111129
}
112130

113-
// after generation of H and S matrix, diag them
114-
DiagoIterAssist::diagH_LAPACK(nstart, n_band, hcc, scc, nstart, en, vcc);
115-
131+
// after generation of H and (optionally) S matrix, diag them
132+
if (S_orth) {
133+
// Solve standard eigenproblem: H_sub * y = lambda * y
134+
DiagoIterAssist::diag_heevx(nstart, n_band, hcc, nstart, en, vcc);
135+
} else {
136+
// Solve generalized eigenproblem: H_sub * y = lambda * S_sub * y
137+
DiagoIterAssist::diag_hegvd(nstart, n_band, hcc, scc, nstart, en, vcc);
138+
}
116139

117140
const int ld_temp = in_place ? dmax : dmin;
118141

@@ -138,14 +161,16 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
138161
delmem_complex_op()(temp);
139162
}
140163
delmem_complex_op()(hcc);
141-
delmem_complex_op()(scc);
164+
if(!S_orth){
165+
delmem_complex_op()(scc);
166+
}
142167
delmem_complex_op()(vcc);
143168

144169
ModuleBase::timer::tick("DiagoAssist", "diag_subspace");
145170
}
146171

147172
template <typename T, typename Device>
148-
void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>* pHamilt,
173+
void DiagoIterAssist<T, Device>::diag_subspace_init(hamilt::Hamilt<T, Device>* pHamilt,
149174
const T* psi,
150175
int psi_nr,
151176
int psi_nc,
@@ -154,8 +179,8 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
154179
const std::function<void(T*, const int)>& add_to_hcc,
155180
const std::function<void(const T* const, const int, const int)>& export_vcc)
156181
{
157-
ModuleBase::TITLE("DiagoIterAssist", "diagH_subspace_init");
158-
ModuleBase::timer::tick("DiagoIterAssist", "diagH_subspace_init");
182+
ModuleBase::TITLE("DiagoIterAssist", "diag_subspace_init");
183+
ModuleBase::timer::tick("DiagoIterAssist", "diag_subspace_init");
159184

160185
// two case:
161186
// 1. pw base: nstart = n_band, psi(nbands * npwx)
@@ -170,7 +195,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
170195
if (pHamilt->ops == nullptr)
171196
{
172197
ModuleBase::WARNING(
173-
"DiagoIterAssist::diagH_subspace_init",
198+
"DiagoIterAssist::diag_subspace_init",
174199
"Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n");
175200
for (int iband = 0; iband < n_band; iband++)
176201
{
@@ -291,7 +316,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
291316
}
292317
}*/
293318

294-
DiagoIterAssist::diagH_LAPACK(nstart, n_band, hcc, scc, nstart, en, vcc);
319+
DiagoIterAssist::diag_hegvd(nstart, n_band, hcc, scc, nstart, en, vcc);
295320

296321
export_vcc(vcc, nstart, n_band);
297322

@@ -353,22 +378,59 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
353378
delmem_complex_op()(hcc);
354379
delmem_complex_op()(scc);
355380
delmem_complex_op()(vcc);
356-
ModuleBase::timer::tick("DiagoIterAssist", "diagH_subspace_init");
381+
ModuleBase::timer::tick("DiagoIterAssist", "diag_subspace_init");
382+
}
383+
384+
template <typename T, typename Device>
385+
void DiagoIterAssist<T, Device>::diag_heevx(const int matrix_size,
386+
const int num_eigenpairs,
387+
const T *h,
388+
const int ldh,
389+
Real *e, // always in CPU
390+
T *v)
391+
{
392+
ModuleBase::TITLE("DiagoIterAssist", "diag_heevx");
393+
ModuleBase::timer::tick("DiagoIterAssist", "diag_heevx");
394+
395+
Real *eigenvalues = nullptr;
396+
// device memory for eigenvalues
397+
resmem_var_op()(eigenvalues, matrix_size);
398+
setmem_var_op()(eigenvalues, 0, matrix_size);
399+
400+
// (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors);
401+
heevx_op<T, Device>()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v);
402+
403+
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
404+
{
405+
#if ((defined __CUDA) || (defined __ROCM))
406+
// eigenvalues to e, from device to host
407+
syncmem_var_d2h_op()(e, eigenvalues, num_eigenpairs);
408+
#endif
409+
}
410+
else if (base_device::get_device_type<Device>(ctx) == base_device::CpuDevice)
411+
{
412+
// eigenvalues to e
413+
syncmem_var_op()(e, eigenvalues, num_eigenpairs);
414+
}
415+
416+
delmem_var_op()(eigenvalues);
417+
418+
ModuleBase::timer::tick("DiagoIterAssist", "diag_heevx");
357419
}
358420

359421
template <typename T, typename Device>
360-
void DiagoIterAssist<T, Device>::diagH_LAPACK(const int nstart,
422+
void DiagoIterAssist<T, Device>::diag_hegvd(const int nstart,
361423
const int nbands,
362-
const T* hcc,
363-
const T* scc,
424+
const T *hcc,
425+
const T *scc,
364426
const int ldh, // nstart
365-
Real* e, // always in CPU
366-
T* vcc)
427+
Real *e, // always in CPU
428+
T *vcc)
367429
{
368-
ModuleBase::TITLE("DiagoIterAssist", "diagH_LAPACK");
369-
ModuleBase::timer::tick("DiagoIterAssist", "diagH_LAPACK");
430+
ModuleBase::TITLE("DiagoIterAssist", "diag_hegvd");
431+
ModuleBase::timer::tick("DiagoIterAssist", "diag_hegvd");
370432

371-
Real* eigenvalues = nullptr;
433+
Real *eigenvalues = nullptr;
372434
resmem_var_op()(eigenvalues, nstart);
373435
setmem_var_op()(eigenvalues, 0, nstart);
374436

@@ -404,7 +466,7 @@ void DiagoIterAssist<T, Device>::diagH_LAPACK(const int nstart,
404466
// dngvx_op<Real, Device>()(ctx, nstart, ldh, hcc, scc, nbands, res, vcc);
405467
// }
406468

407-
ModuleBase::timer::tick("DiagoIterAssist", "diagH_LAPACK");
469+
ModuleBase::timer::tick("DiagoIterAssist", "diag_hegvd");
408470
}
409471

410472
template <typename T, typename Device>
@@ -428,10 +490,10 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
428490
{ // code block to calculate hcc and scc
429491
setmem_complex_op()(temp, 0, nstart * dmax);
430492

431-
T* hphi = temp;
493+
T* hpsi = temp;
432494
// do hPsi for all bands
433495
psi::Range all_bands_range(1, psi.get_current_k(), 0, nstart - 1);
434-
hpsi_info hpsi_in(&psi, all_bands_range, hphi);
496+
hpsi_info hpsi_in(&psi, all_bands_range, hpsi);
435497
pHamilt->ops->hPsi(hpsi_in);
436498

437499
ModuleBase::gemm_op<T, Device>()('C',
@@ -442,15 +504,15 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
442504
&one,
443505
psi.get_pointer(),
444506
dmax,
445-
hphi,
507+
hpsi,
446508
dmax,
447509
&zero,
448510
hcc,
449511
nstart);
450512

451-
T* sphi = temp;
513+
T* spsi = temp;
452514
// do sPsi for all bands
453-
pHamilt->sPsi(psi.get_pointer(), sphi, dmax, dmin, nstart);
515+
pHamilt->sPsi(psi.get_pointer(), spsi, dmax, dmin, nstart);
454516

455517
ModuleBase::gemm_op<T, Device>()('C',
456518
'N',
@@ -460,7 +522,7 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
460522
&one,
461523
psi.get_pointer(),
462524
dmax,
463-
sphi,
525+
spsi,
464526
dmax,
465527
&zero,
466528
scc,
@@ -496,7 +558,7 @@ void DiagoIterAssist<T, Device>::diag_responce( const T* hcc,
496558
setmem_complex_op()(vcc, 0, nstart * nstart);
497559

498560
// after generation of H and S matrix, diag them
499-
DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc);
561+
DiagoIterAssist::diag_hegvd(nstart, nstart, hcc, scc, nstart, en, vcc);
500562

501563
{ // code block to calculate tar_mat
502564
ModuleBase::gemm_op<T, Device>()('N',
@@ -538,7 +600,7 @@ void DiagoIterAssist<T, Device>::diag_subspace_psi(const T* hcc,
538600
setmem_complex_op()(vcc, 0, nstart * nstart);
539601

540602
// after generation of H and S matrix, diag them
541-
DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc);
603+
DiagoIterAssist::diag_hegvd(nstart, nstart, hcc, scc, nstart, en, vcc);
542604

543605
{ // code block to calculate tar_mat
544606
const int dmin = evc.get_current_ngk();
@@ -572,7 +634,7 @@ template <typename T, typename Device>
572634
bool DiagoIterAssist<T, Device>::test_exit_cond(const int& ntry, const int& notconv)
573635
{
574636
//================================================================
575-
// If this logical function is true, need to do diagH_subspace
637+
// If this logical function is true, need to do diag_subspace
576638
// and cg again.
577639
//================================================================
578640

@@ -588,7 +650,7 @@ bool DiagoIterAssist<T, Device>::test_exit_cond(const int& ntry, const int& notc
588650
const bool f2 = ((!scf && (notconv > 0)));
589651

590652
// if self consistent calculation, if not converged > 5,
591-
// using diagH_subspace and cg method again. ntry++
653+
// using diag_subspace and cg method again. ntry++
592654
const bool f3 = ((scf && (notconv > 5)));
593655
return (f1 && (f2 || f3));
594656
}

0 commit comments

Comments
 (0)