@@ -96,6 +96,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
9696
9797template <typename T, typename Device>
9898int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
99+ const HPsiFunc& spsi_func,
99100 T* psi_in,
100101 const int psi_in_dmax,
101102 Real* eigenvalue_in_hsolver,
@@ -134,7 +135,11 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
134135 // hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
135136 hpsi_func (this ->psi_in_iter , this ->hphi , this ->dim , this ->notconv );
136137
137- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->hphi , this ->hcc , this ->scc );
138+ // compute s*psi_in_iter
139+ // sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
140+ spsi_func (this ->psi_in_iter , this ->sphi , this ->dim , this ->notconv );
141+
142+ this ->cal_elem (this ->dim , nbase, this ->notconv , this ->sphi , this ->hphi , this ->hcc , this ->scc );
138143
139144 this ->diag_zhegvx (nbase, this ->notconv , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
140145
@@ -161,7 +166,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
161166 unconv.data (),
162167 &eigenvalue_iter);
163168
164- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->hphi , this ->hcc , this ->scc );
169+ this ->cal_elem (this ->dim , nbase, this ->notconv , this ->sphi , this ->hphi , this ->hcc , this ->scc );
165170
166171 this ->diag_zhegvx (nbase, this ->n_band , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
167172
@@ -405,7 +410,7 @@ template <typename T, typename Device>
405410void Diago_DavSubspace<T, Device>::cal_elem(const int & dim,
406411 int & nbase,
407412 const int & notconv,
408- const T* psi_iter ,
413+ const T* spsi ,
409414 const T* hphi,
410415 T* hcc,
411416 T* scc)
@@ -416,39 +421,39 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
416421 ModuleBase::gemm_op_mt<T, Device>()
417422#else
418423 ModuleBase::gemm_op<T, Device>()
419- #endif
420- (' C' ,
421- ' N' ,
422- nbase + notconv,
423- notconv,
424- this ->dim ,
425- this ->one ,
426- psi_iter ,
427- this ->dim ,
428- &hphi[nbase * this ->dim ],
429- this ->dim ,
430- this ->zero ,
431- &hcc[nbase * this ->nbase_x ],
432- this ->nbase_x );
424+ #endif
425+ (' C' ,
426+ ' N' ,
427+ nbase + notconv,
428+ notconv,
429+ this ->dim ,
430+ this ->one ,
431+ spsi ,
432+ this ->dim ,
433+ &hphi[nbase * this ->dim ],
434+ this ->dim ,
435+ this ->zero ,
436+ &hcc[nbase * this ->nbase_x ],
437+ this ->nbase_x );
433438
434439#ifdef __DSP
435440 ModuleBase::gemm_op_mt<T, Device>()
436441#else
437442 ModuleBase::gemm_op<T, Device>()
438443#endif
439- (' C' ,
440- ' N' ,
441- nbase + notconv,
442- notconv,
443- this ->dim ,
444- this ->one ,
445- psi_iter ,
446- this ->dim ,
447- psi_iter + nbase * this ->dim ,
448- this ->dim ,
449- this ->zero ,
450- &scc[nbase * this ->nbase_x ],
451- this ->nbase_x );
444+ (' C' ,
445+ ' N' ,
446+ nbase + notconv,
447+ notconv,
448+ this ->dim ,
449+ this ->one ,
450+ spsi ,
451+ this ->dim ,
452+ spsi + nbase * this ->dim ,
453+ this ->dim ,
454+ this ->zero ,
455+ &scc[nbase * this ->nbase_x ],
456+ this ->nbase_x );
452457
453458#ifdef __MPI
454459 if (this ->diag_comm .nproc > 1 )
@@ -776,6 +781,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
776781
777782template <typename T, typename Device>
778783int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
784+ const HPsiFunc& spsi_func,
779785 T* psi_in,
780786 const int psi_in_dmax,
781787 Real* eigenvalue_in_hsolver,
@@ -791,7 +797,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
791797 do
792798 {
793799
794- sum_iter += this ->diag_once (hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
800+ sum_iter += this ->diag_once (hpsi_func, spsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
795801
796802 ++ntry;
797803
0 commit comments