@@ -54,6 +54,10 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
5454 resmem_complex_op ()(this ->hphi , this ->nbase_x * this ->dim , " DAV::hphi" );
5555 setmem_complex_op ()(this ->hphi , 0 , this ->nbase_x * this ->dim );
5656
57+ // the product of S and psi in the reduced psi set
58+ resmem_complex_op ()(this ->sphi , this ->nbase_x * this ->dim , " DAV::sphi" );
59+ setmem_complex_op ()(this ->sphi , 0 , this ->nbase_x * this ->dim );
60+
5761 // Hamiltonian on the reduced psi set
5862 resmem_complex_op ()(this ->hcc , this ->nbase_x * this ->nbase_x , " DAV::hcc" );
5963 setmem_complex_op ()(this ->hcc , 0 , this ->nbase_x * this ->nbase_x );
@@ -96,6 +100,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
96100
97101template <typename T, typename Device>
98102int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
103+ const HPsiFunc& spsi_func,
99104 T* psi_in,
100105 const int psi_in_dmax,
101106 Real* eigenvalue_in_hsolver,
@@ -134,7 +139,11 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
134139 // hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
135140 hpsi_func (this ->psi_in_iter , this ->hphi , this ->dim , this ->notconv );
136141
137- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->hphi , this ->hcc , this ->scc );
142+ // compute s*psi_in_iter
143+ // sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
144+ spsi_func (this ->psi_in_iter , this ->sphi , this ->dim , this ->notconv );
145+
146+ this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->sphi , this ->hphi , this ->hcc , this ->scc );
138147
139148 this ->diag_zhegvx (nbase, this ->notconv , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
140149
@@ -152,16 +161,25 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
152161 dav_iter++;
153162
154163 this ->cal_grad (hpsi_func,
164+ spsi_func,
155165 this ->dim ,
156166 nbase,
157167 this ->notconv ,
158168 this ->psi_in_iter ,
159169 this ->hphi ,
170+ this ->sphi ,
160171 this ->vcc ,
161172 unconv.data (),
162173 &eigenvalue_iter);
163174
164- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->hphi , this ->hcc , this ->scc );
175+ this ->cal_elem (this ->dim ,
176+ nbase,
177+ this ->notconv ,
178+ this ->psi_in_iter ,
179+ this ->sphi ,
180+ this ->hphi ,
181+ this ->hcc ,
182+ this ->scc );
165183
166184 this ->diag_zhegvx (nbase, this ->n_band , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
167185
@@ -238,6 +256,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
238256 eigenvalue_in_hsolver,
239257 this ->psi_in_iter ,
240258 this ->hphi ,
259+ this ->sphi ,
241260 this ->hcc ,
242261 this ->scc ,
243262 this ->vcc );
@@ -255,11 +274,13 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
255274
256275template <typename T, typename Device>
257276void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
277+ const HPsiFunc& spsi_func,
258278 const int & dim,
259279 const int & nbase,
260280 const int & notconv,
261281 T* psi_iter,
262282 T* hphi,
283+ T* spsi,
263284 T* vcc,
264285 const int * unconv,
265286 std::vector<Real>* eigenvalue_iter)
@@ -331,7 +352,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
331352 notconv,
332353 nbase,
333354 this ->one ,
334- psi_iter ,
355+ sphi ,
335356 this ->dim ,
336357 vcc,
337358 this ->nbase_x ,
@@ -396,6 +417,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
396417 // update hpsi[:, nbase:nbase+notconv]
397418 // hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
398419 hpsi_func (psi_iter + nbase * dim, hphi + nbase * this ->dim , this ->dim , notconv);
420+ spsi_func (psi_iter + nbase * dim, sphi + nbase * this ->dim , this ->dim , notconv);
399421
400422 ModuleBase::timer::tick (" Diago_DavSubspace" , " cal_grad" );
401423 return ;
@@ -406,6 +428,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
406428 int & nbase,
407429 const int & notconv,
408430 const T* psi_iter,
431+ const T* spsi,
409432 const T* hphi,
410433 T* hcc,
411434 T* scc)
@@ -416,39 +439,39 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
416439 ModuleBase::gemm_op_mt<T, Device>()
417440#else
418441 ModuleBase::gemm_op<T, Device>()
419- #endif
420- (' C' ,
421- ' N' ,
422- nbase + notconv,
423- notconv,
424- this ->dim ,
425- this ->one ,
426- psi_iter,
427- this ->dim ,
428- &hphi[nbase * this ->dim ],
429- this ->dim ,
430- this ->zero ,
431- &hcc[nbase * this ->nbase_x ],
432- this ->nbase_x );
442+ #endif
443+ (' C' ,
444+ ' N' ,
445+ nbase + notconv,
446+ notconv,
447+ this ->dim ,
448+ this ->one ,
449+ psi_iter,
450+ this ->dim ,
451+ &hphi[nbase * this ->dim ],
452+ this ->dim ,
453+ this ->zero ,
454+ &hcc[nbase * this ->nbase_x ],
455+ this ->nbase_x );
433456
434457#ifdef __DSP
435458 ModuleBase::gemm_op_mt<T, Device>()
436459#else
437460 ModuleBase::gemm_op<T, Device>()
438461#endif
439- (' C' ,
440- ' N' ,
441- nbase + notconv,
442- notconv,
443- this ->dim ,
444- this ->one ,
445- psi_iter,
446- this ->dim ,
447- psi_iter + nbase * this ->dim ,
448- this ->dim ,
449- this ->zero ,
450- &scc[nbase * this ->nbase_x ],
451- this ->nbase_x );
462+ (' C' ,
463+ ' N' ,
464+ nbase + notconv,
465+ notconv,
466+ this ->dim ,
467+ this ->one ,
468+ psi_iter,
469+ this ->dim ,
470+ spsi + nbase * this ->dim ,
471+ this ->dim ,
472+ this ->zero ,
473+ &scc[nbase * this ->nbase_x ],
474+ this ->nbase_x );
452475
453476#ifdef __MPI
454477 if (this ->diag_comm .nproc > 1 )
@@ -685,10 +708,11 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
685708 const Real* eigenvalue_in_hsolver,
686709 // const psi::Psi<T, Device>& psi,
687710 T* psi_iter,
688- T* hp,
689- T* sp,
690- T* hc,
691- T* vc)
711+ T* hphi,
712+ T* sphi,
713+ T* hcc,
714+ T* scc,
715+ T* vcc)
692716{
693717 ModuleBase::timer::tick (" Diago_DavSubspace" , " refresh" );
694718
@@ -714,6 +738,28 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
714738 // update hphi
715739 syncmem_complex_op ()(hphi, psi_iter + nband * this ->dim , this ->dim * nband);
716740
741+ #ifdef __DSP
742+ ModuleBase::gemm_op_mt<T, Device>()
743+ #else
744+ ModuleBase::gemm_op<T, Device>()
745+ #endif
746+ (' N' ,
747+ ' N' ,
748+ this ->dim ,
749+ nband,
750+ nbase,
751+ this ->one ,
752+ this ->sphi ,
753+ this ->dim ,
754+ this ->vcc ,
755+ this ->nbase_x ,
756+ this ->zero ,
757+ psi_iter + nband * this ->dim ,
758+ this ->dim );
759+
760+ // update sphi
761+ syncmem_complex_op ()(sphi, psi_iter + nband * this ->dim , this ->dim * nband);
762+
717763 nbase = nband;
718764
719765 // set hcc/scc/vcc to 0
@@ -776,6 +822,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
776822
777823template <typename T, typename Device>
778824int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
825+ const HPsiFunc& spsi_func,
779826 T* psi_in,
780827 const int psi_in_dmax,
781828 Real* eigenvalue_in_hsolver,
@@ -791,7 +838,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
791838 do
792839 {
793840
794- sum_iter += this ->diag_once (hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
841+ sum_iter += this ->diag_once (hpsi_func, spsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
795842
796843 ++ntry;
797844
0 commit comments