@@ -54,6 +54,10 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
5454 resmem_complex_op ()(this ->hphi , this ->nbase_x * this ->dim , " DAV::hphi" );
5555 setmem_complex_op ()(this ->hphi , 0 , this ->nbase_x * this ->dim );
5656
57+ // the product of S and psi in the reduced psi set
58+ resmem_complex_op ()(this ->sphi , this ->nbase_x * this ->dim , " DAV::sphi" );
59+ setmem_complex_op ()(this ->sphi , 0 , this ->nbase_x * this ->dim );
60+
5761 // Hamiltonian on the reduced psi set
5862 resmem_complex_op ()(this ->hcc , this ->nbase_x * this ->nbase_x , " DAV::hcc" );
5963 setmem_complex_op ()(this ->hcc , 0 , this ->nbase_x * this ->nbase_x );
@@ -139,7 +143,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
139143 // sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
140144 spsi_func (this ->psi_in_iter , this ->sphi , this ->dim , this ->notconv );
141145
142- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->sphi , this ->hphi , this ->hcc , this ->scc );
146+ this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this -> sphi , this ->hphi , this ->hcc , this ->scc );
143147
144148 this ->diag_zhegvx (nbase, this ->notconv , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
145149
@@ -157,16 +161,25 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
157161 dav_iter++;
158162
159163 this ->cal_grad (hpsi_func,
164+ spsi_func,
160165 this ->dim ,
161166 nbase,
162167 this ->notconv ,
163168 this ->psi_in_iter ,
164169 this ->hphi ,
170+ this ->sphi ,
165171 this ->vcc ,
166172 unconv.data (),
167173 &eigenvalue_iter);
168174
169- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->sphi , this ->hphi , this ->hcc , this ->scc );
175+ this ->cal_elem (this ->dim ,
176+ nbase,
177+ this ->notconv ,
178+ this ->psi_in_iter ,
179+ this ->sphi ,
180+ this ->hphi ,
181+ this ->hcc ,
182+ this ->scc );
170183
171184 this ->diag_zhegvx (nbase, this ->n_band , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
172185
@@ -243,6 +256,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
243256 eigenvalue_in_hsolver,
244257 this ->psi_in_iter ,
245258 this ->hphi ,
259+ this ->sphi ,
246260 this ->hcc ,
247261 this ->scc ,
248262 this ->vcc );
@@ -260,11 +274,13 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
260274
261275template <typename T, typename Device>
262276void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
277+ const HPsiFunc& spsi_func,
263278 const int & dim,
264279 const int & nbase,
265280 const int & notconv,
266281 T* psi_iter,
267282 T* hphi,
283+ T* spsi,
268284 T* vcc,
269285 const int * unconv,
270286 std::vector<Real>* eigenvalue_iter)
@@ -336,7 +352,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
336352 notconv,
337353 nbase,
338354 this ->one ,
339- psi_iter ,
355+ sphi ,
340356 this ->dim ,
341357 vcc,
342358 this ->nbase_x ,
@@ -401,6 +417,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
401417 // update hpsi[:, nbase:nbase+notconv]
402418 // hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
403419 hpsi_func (psi_iter + nbase * dim, hphi + nbase * this ->dim , this ->dim , notconv);
420+ spsi_func (psi_iter + nbase * dim, sphi + nbase * this ->dim , this ->dim , notconv);
404421
405422 ModuleBase::timer::tick (" Diago_DavSubspace" , " cal_grad" );
406423 return ;
@@ -410,6 +427,7 @@ template <typename T, typename Device>
410427void Diago_DavSubspace<T, Device>::cal_elem(const int & dim,
411428 int & nbase,
412429 const int & notconv,
430+ const T* psi_iter,
413431 const T* spsi,
414432 const T* hphi,
415433 T* hcc,
@@ -428,7 +446,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
428446 notconv,
429447 this ->dim ,
430448 this ->one ,
431- spsi ,
449+ psi_iter ,
432450 this ->dim ,
433451 &hphi[nbase * this ->dim ],
434452 this ->dim ,
@@ -447,7 +465,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
447465 notconv,
448466 this ->dim ,
449467 this ->one ,
450- spsi ,
468+ psi_iter ,
451469 this ->dim ,
452470 spsi + nbase * this ->dim ,
453471 this ->dim ,
@@ -690,10 +708,11 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
690708 const Real* eigenvalue_in_hsolver,
691709 // const psi::Psi<T, Device>& psi,
692710 T* psi_iter,
693- T* hp,
694- T* sp,
695- T* hc,
696- T* vc)
711+ T* hphi,
712+ T* sphi,
713+ T* hcc,
714+ T* scc,
715+ T* vcc)
697716{
698717 ModuleBase::timer::tick (" Diago_DavSubspace" , " refresh" );
699718
@@ -719,6 +738,28 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
719738 // update hphi
720739 syncmem_complex_op ()(hphi, psi_iter + nband * this ->dim , this ->dim * nband);
721740
741+ #ifdef __DSP
742+ ModuleBase::gemm_op_mt<T, Device>()
743+ #else
744+ ModuleBase::gemm_op<T, Device>()
745+ #endif
746+ (' N' ,
747+ ' N' ,
748+ this ->dim ,
749+ nband,
750+ nbase,
751+ this ->one ,
752+ this ->sphi ,
753+ this ->dim ,
754+ this ->vcc ,
755+ this ->nbase_x ,
756+ this ->zero ,
757+ psi_iter + nband * this ->dim ,
758+ this ->dim );
759+
760+ // update sphi
761+ syncmem_complex_op ()(sphi, psi_iter + nband * this ->dim , this ->dim * nband);
762+
722763 nbase = nband;
723764
724765 // set hcc/scc/vcc to 0
0 commit comments