@@ -18,39 +18,52 @@ namespace hsolver
1818// Produces on output n_band eigenvectors (n_band <= nstart) in evc.
1919// ----------------------------------------------------------------------
2020template <typename T, typename Device>
21- void DiagoIterAssist<T, Device>::diagH_subspace (const hamilt::Hamilt<T, Device>* const pHamilt, // hamiltonian operator carrier
21+ void DiagoIterAssist<T, Device>::diag_subspace (const hamilt::Hamilt<T, Device>* const pHamilt, // hamiltonian operator carrier
2222 const psi::Psi<T, Device>& psi, // [in] wavefunction
23- psi::Psi<T, Device>& evc, // [out] wavefunction
23+ psi::Psi<T, Device>& evc, // [out] wavefunction, eigenvectors
2424 Real* en, // [out] eigenvalues
25- int n_band // [in] number of bands to be calculated, also number of rows
25+ int n_band, // [in] number of bands to be calculated, also number of rows
2626 // of evc, if set to 0, n_band = nstart, default 0
27+ const bool S_orth // [in] if true, psi is assumed to be already S-orthogonalized
2728)
2829{
29- ModuleBase::TITLE (" DiagoAssist " , " diag_subspace" );
30- ModuleBase::timer::tick (" DiagoAssist " , " diag_subspace" );
30+ ModuleBase::TITLE (" DiagoIterAssist " , " diag_subspace" );
31+ ModuleBase::timer::tick (" DiagoIterAssist " , " diag_subspace" );
3132
3233 // two case:
3334 // 1. pw base: nstart = n_band, psi(nbands * npwx)
3435 // 2. lcao_in_pw base: nstart >= n_band, psi(NLOCAL * npwx)
3536 const int nstart = psi.get_nbands ();
37+ // n_band = 0 means default, set n_band = nstart
3638 if (n_band == 0 )
3739 {
3840 n_band = nstart;
3941 }
4042 assert (n_band <= nstart);
4143
44+ // scc is overlap (optional, only needed if input is not s-orthogonal)
4245 T *hcc = nullptr , *scc = nullptr , *vcc = nullptr ;
46+
47+ // hcc is reduced hamiltonian matrix
4348 resmem_complex_op ()(hcc, nstart * nstart, " DiagSub::hcc" );
44- resmem_complex_op ()(scc, nstart * nstart, " DiagSub::scc" );
45- resmem_complex_op ()(vcc, nstart * nstart, " DiagSub::vcc" );
4649 setmem_complex_op ()(hcc, 0 , nstart * nstart);
47- setmem_complex_op ()(scc, 0 , nstart * nstart);
50+
51+ // scc is overlap matrix, only needed when psi is not orthogonal
52+ if (!S_orth){
53+ resmem_complex_op ()(scc, nstart * nstart, " DiagSub::scc" );
54+ setmem_complex_op ()(scc, 0 , nstart * nstart);
55+ }
56+
57+ // vcc is eigenvector matrix of the reduced generalized eigenvalue problem
58+ resmem_complex_op ()(vcc, nstart * nstart, " DiagSub::vcc" );
4859 setmem_complex_op ()(vcc, 0 , nstart * nstart);
4960
61+ // dmin is the active number of plane waves or atomic orbitals
62+ // dmax is the leading dimension of psi
5063 const int dmin = psi.get_current_ngk ();
5164 const int dmax = psi.get_nbasis ();
5265
53- T* temp = nullptr ;
66+ T * temp = nullptr ; // / temporary array for calculation of evc
5467 bool in_place = false ; // /< if temp and evc share the same memory
5568 if (psi.get_pointer () != evc.get_pointer () && psi.get_nbands () == evc.get_nbands ())
5669 { // use memory of evc as temp
@@ -65,10 +78,10 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
6578 { // code block to calculate hcc and scc
6679 setmem_complex_op ()(temp, 0 , nstart * dmax);
6780
68- T* hphi = temp;
81+ T *hpsi = temp;
6982 // do hPsi for all bands
7083 psi::Range all_bands_range (1 , psi.get_current_k (), 0 , nstart - 1 );
71- hpsi_info hpsi_in (&psi, all_bands_range, hphi );
84+ hpsi_info hpsi_in (&psi, all_bands_range, hpsi );
7285 pHamilt->ops ->hPsi (hpsi_in);
7386
7487 ModuleBase::gemm_op<T, Device>()(' C' ,
@@ -79,40 +92,50 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
7992 &one,
8093 psi.get_pointer (),
8194 dmax,
82- hphi ,
95+ hpsi ,
8396 dmax,
8497 &zero,
8598 hcc,
8699 nstart);
87100
88- T* sphi = temp;
89- // do sPsi for all bands
90- pHamilt->sPsi (psi.get_pointer (), sphi, dmax, dmin, nstart);
91-
92- ModuleBase::gemm_op<T, Device>()(' C' ,
93- ' N' ,
94- nstart,
95- nstart,
96- dmin,
97- &one,
98- psi.get_pointer (),
99- dmax,
100- sphi,
101- dmax,
102- &zero,
103- scc,
104- nstart);
101+ if (!S_orth){
102+ // Only calculate S_sub if not orthogonal
103+ T *spsi = temp;
104+ // do sPsi for all bands
105+ pHamilt->sPsi (psi.get_pointer (), spsi, dmax, dmin, nstart);
106+
107+ ModuleBase::gemm_op<T, Device>()(' C' ,
108+ ' N' ,
109+ nstart,
110+ nstart,
111+ dmin,
112+ &one,
113+ psi.get_pointer (),
114+ dmax,
115+ spsi,
116+ dmax,
117+ &zero,
118+ scc,
119+ nstart);
120+ }
105121 }
106122
107123 if (GlobalV::NPROC_IN_POOL > 1 )
108124 {
109125 Parallel_Reduce::reduce_pool (hcc, nstart * nstart);
110- Parallel_Reduce::reduce_pool (scc, nstart * nstart);
126+ if (!S_orth){
127+ Parallel_Reduce::reduce_pool (scc, nstart * nstart);
128+ }
111129 }
112130
113- // after generation of H and S matrix, diag them
114- DiagoIterAssist::diagH_LAPACK (nstart, n_band, hcc, scc, nstart, en, vcc);
115-
131+ // after generation of H and (optionally) S matrix, diag them
132+ if (S_orth) {
133+ // Solve standard eigenproblem: H_sub * y = lambda * y
134+ DiagoIterAssist::diag_heevx (nstart, n_band, hcc, nstart, en, vcc);
135+ } else {
136+ // Solve generalized eigenproblem: H_sub * y = lambda * S_sub * y
137+ DiagoIterAssist::diag_hegvd (nstart, n_band, hcc, scc, nstart, en, vcc);
138+ }
116139
117140 const int ld_temp = in_place ? dmax : dmin;
118141
@@ -138,14 +161,16 @@ void DiagoIterAssist<T, Device>::diagH_subspace(const hamilt::Hamilt<T, Device>*
138161 delmem_complex_op ()(temp);
139162 }
140163 delmem_complex_op ()(hcc);
141- delmem_complex_op ()(scc);
164+ if (!S_orth){
165+ delmem_complex_op ()(scc);
166+ }
142167 delmem_complex_op ()(vcc);
143168
144169 ModuleBase::timer::tick (" DiagoAssist" , " diag_subspace" );
145170}
146171
147172template <typename T, typename Device>
148- void DiagoIterAssist<T, Device>::diagH_subspace_init (hamilt::Hamilt<T, Device>* pHamilt,
173+ void DiagoIterAssist<T, Device>::diag_subspace_init (hamilt::Hamilt<T, Device>* pHamilt,
149174 const T* psi,
150175 int psi_nr,
151176 int psi_nc,
@@ -154,8 +179,8 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
154179 const std::function<void (T*, const int )>& add_to_hcc,
155180 const std::function<void (const T* const , const int , const int )>& export_vcc)
156181{
157- ModuleBase::TITLE (" DiagoIterAssist" , " diagH_subspace_init " );
158- ModuleBase::timer::tick (" DiagoIterAssist" , " diagH_subspace_init " );
182+ ModuleBase::TITLE (" DiagoIterAssist" , " diag_subspace_init " );
183+ ModuleBase::timer::tick (" DiagoIterAssist" , " diag_subspace_init " );
159184
160185 // two case:
161186 // 1. pw base: nstart = n_band, psi(nbands * npwx)
@@ -170,7 +195,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
170195 if (pHamilt->ops == nullptr )
171196 {
172197 ModuleBase::WARNING (
173- " DiagoIterAssist::diagH_subspace_init " ,
198+ " DiagoIterAssist::diag_subspace_init " ,
174199 " Severe warning: Operators in Hamilt are not allocated yet, will return value of psi to evc directly\n " );
175200 for (int iband = 0 ; iband < n_band; iband++)
176201 {
@@ -291,7 +316,7 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
291316 }
292317 }*/
293318
294- DiagoIterAssist::diagH_LAPACK (nstart, n_band, hcc, scc, nstart, en, vcc);
319+ DiagoIterAssist::diag_hegvd (nstart, n_band, hcc, scc, nstart, en, vcc);
295320
296321 export_vcc (vcc, nstart, n_band);
297322
@@ -353,22 +378,59 @@ void DiagoIterAssist<T, Device>::diagH_subspace_init(hamilt::Hamilt<T, Device>*
353378 delmem_complex_op ()(hcc);
354379 delmem_complex_op ()(scc);
355380 delmem_complex_op ()(vcc);
356- ModuleBase::timer::tick (" DiagoIterAssist" , " diagH_subspace_init" );
381+ ModuleBase::timer::tick (" DiagoIterAssist" , " diag_subspace_init" );
382+ }
383+
384+ template <typename T, typename Device>
385+ void DiagoIterAssist<T, Device>::diag_heevx(const int matrix_size,
386+ const int num_eigenpairs,
387+ const T *h,
388+ const int ldh,
389+ Real *e, // always in CPU
390+ T *v)
391+ {
392+ ModuleBase::TITLE (" DiagoIterAssist" , " diag_heevx" );
393+ ModuleBase::timer::tick (" DiagoIterAssist" , " diag_heevx" );
394+
395+ Real *eigenvalues = nullptr ;
396+ // device memory for eigenvalues
397+ resmem_var_op ()(eigenvalues, matrix_size);
398+ setmem_var_op ()(eigenvalues, 0 , matrix_size);
399+
400+ // (const Device *d, const int matrix_size, const int lda, const T *A, const int num_eigenpairs, Real *eigenvalues, T *eigenvectors);
401+ heevx_op<T, Device>()(ctx, matrix_size, ldh, h, num_eigenpairs, eigenvalues, v);
402+
403+ if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
404+ {
405+ #if ((defined __CUDA) || (defined __ROCM))
406+ // eigenvalues to e, from device to host
407+ syncmem_var_d2h_op ()(e, eigenvalues, num_eigenpairs);
408+ #endif
409+ }
410+ else if (base_device::get_device_type<Device>(ctx) == base_device::CpuDevice)
411+ {
412+ // eigenvalues to e
413+ syncmem_var_op ()(e, eigenvalues, num_eigenpairs);
414+ }
415+
416+ delmem_var_op ()(eigenvalues);
417+
418+ ModuleBase::timer::tick (" DiagoIterAssist" , " diag_heevx" );
357419}
358420
359421template <typename T, typename Device>
360- void DiagoIterAssist<T, Device>::diagH_LAPACK (const int nstart,
422+ void DiagoIterAssist<T, Device>::diag_hegvd (const int nstart,
361423 const int nbands,
362- const T* hcc,
363- const T* scc,
424+ const T * hcc,
425+ const T * scc,
364426 const int ldh, // nstart
365- Real* e, // always in CPU
366- T* vcc)
427+ Real * e, // always in CPU
428+ T * vcc)
367429{
368- ModuleBase::TITLE (" DiagoIterAssist" , " diagH_LAPACK " );
369- ModuleBase::timer::tick (" DiagoIterAssist" , " diagH_LAPACK " );
430+ ModuleBase::TITLE (" DiagoIterAssist" , " diag_hegvd " );
431+ ModuleBase::timer::tick (" DiagoIterAssist" , " diag_hegvd " );
370432
371- Real* eigenvalues = nullptr ;
433+ Real * eigenvalues = nullptr ;
372434 resmem_var_op ()(eigenvalues, nstart);
373435 setmem_var_op ()(eigenvalues, 0 , nstart);
374436
@@ -404,7 +466,7 @@ void DiagoIterAssist<T, Device>::diagH_LAPACK(const int nstart,
404466 // dngvx_op<Real, Device>()(ctx, nstart, ldh, hcc, scc, nbands, res, vcc);
405467 // }
406468
407- ModuleBase::timer::tick (" DiagoIterAssist" , " diagH_LAPACK " );
469+ ModuleBase::timer::tick (" DiagoIterAssist" , " diag_hegvd " );
408470}
409471
410472template <typename T, typename Device>
@@ -428,10 +490,10 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
428490 { // code block to calculate hcc and scc
429491 setmem_complex_op ()(temp, 0 , nstart * dmax);
430492
431- T* hphi = temp;
493+ T* hpsi = temp;
432494 // do hPsi for all bands
433495 psi::Range all_bands_range (1 , psi.get_current_k (), 0 , nstart - 1 );
434- hpsi_info hpsi_in (&psi, all_bands_range, hphi );
496+ hpsi_info hpsi_in (&psi, all_bands_range, hpsi );
435497 pHamilt->ops ->hPsi (hpsi_in);
436498
437499 ModuleBase::gemm_op<T, Device>()(' C' ,
@@ -442,15 +504,15 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
442504 &one,
443505 psi.get_pointer (),
444506 dmax,
445- hphi ,
507+ hpsi ,
446508 dmax,
447509 &zero,
448510 hcc,
449511 nstart);
450512
451- T* sphi = temp;
513+ T* spsi = temp;
452514 // do sPsi for all bands
453- pHamilt->sPsi (psi.get_pointer (), sphi , dmax, dmin, nstart);
515+ pHamilt->sPsi (psi.get_pointer (), spsi , dmax, dmin, nstart);
454516
455517 ModuleBase::gemm_op<T, Device>()(' C' ,
456518 ' N' ,
@@ -460,7 +522,7 @@ void DiagoIterAssist<T, Device>::cal_hs_subspace(const hamilt::Hamilt<T, Device>
460522 &one,
461523 psi.get_pointer (),
462524 dmax,
463- sphi ,
525+ spsi ,
464526 dmax,
465527 &zero,
466528 scc,
@@ -496,7 +558,7 @@ void DiagoIterAssist<T, Device>::diag_responce( const T* hcc,
496558 setmem_complex_op ()(vcc, 0 , nstart * nstart);
497559
498560 // after generation of H and S matrix, diag them
499- DiagoIterAssist::diagH_LAPACK (nstart, nstart, hcc, scc, nstart, en, vcc);
561+ DiagoIterAssist::diag_hegvd (nstart, nstart, hcc, scc, nstart, en, vcc);
500562
501563 { // code block to calculate tar_mat
502564 ModuleBase::gemm_op<T, Device>()(' N' ,
@@ -538,7 +600,7 @@ void DiagoIterAssist<T, Device>::diag_subspace_psi(const T* hcc,
538600 setmem_complex_op ()(vcc, 0 , nstart * nstart);
539601
540602 // after generation of H and S matrix, diag them
541- DiagoIterAssist::diagH_LAPACK (nstart, nstart, hcc, scc, nstart, en, vcc);
603+ DiagoIterAssist::diag_hegvd (nstart, nstart, hcc, scc, nstart, en, vcc);
542604
543605 { // code block to calculate tar_mat
544606 const int dmin = evc.get_current_ngk ();
@@ -572,7 +634,7 @@ template <typename T, typename Device>
572634bool DiagoIterAssist<T, Device>::test_exit_cond(const int & ntry, const int & notconv)
573635{
574636 // ================================================================
575- // If this logical function is true, need to do diagH_subspace
637+ // If this logical function is true, need to do diag_subspace
576638 // and cg again.
577639 // ================================================================
578640
@@ -588,7 +650,7 @@ bool DiagoIterAssist<T, Device>::test_exit_cond(const int& ntry, const int& notc
588650 const bool f2 = ((!scf && (notconv > 0 )));
589651
590652 // if self consistent calculation, if not converged > 5,
591- // using diagH_subspace and cg method again. ntry++
653+ // using diag_subspace and cg method again. ntry++
592654 const bool f3 = ((scf && (notconv > 5 )));
593655 return (f1 && (f2 || f3));
594656}
0 commit comments