@@ -49,12 +49,6 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
4949 ModuleBase::ComplexMatrix vc (nbase_x, nbase_x); // Eigenvectors of hc
5050 std::vector<double > eigenvalue (nbase_x); // the lowest N eigenvalues of hc
5151
52- std::vector<std::complex <double >> psi_m (dim);
53- std::vector<std::complex <double >> hpsi (dim);
54- std::vector<std::complex <double >> spsi (dim);
55- std::vector<std::complex <double >> ppsi (dim);
56- std::vector<std::complex <double >> respsi (dim);
57-
5852 std::vector<bool > convflag (nband, false ); // convflag[m] = true if the m th band is convergent
5953 std::vector<int > unconv (nband); // unconv[m] store the number of the m th unconvergent band
6054
@@ -142,8 +136,7 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
142136 sp,
143137 vc,
144138 unconv.data (),
145- eigenvalue.data (),
146- respsi.data ());
139+ eigenvalue.data ());
147140
148141 this ->cal_elem (dim, nbase, this ->notconv , basis, hp, sp, hc, sc);
149142
@@ -235,8 +228,7 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
235228 ModuleBase::ComplexMatrix &sp,
236229 const ModuleBase::ComplexMatrix &vc,
237230 const int *unconv,
238- const double *eigenvalue,
239- std::complex <double > *respsi)
231+ const double *eigenvalue)
240232{
241233 if (test_david == 1 )
242234 ModuleBase::TITLE (" DiagoDavid" , " cal_grad" );
@@ -251,32 +243,63 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
251243 // expand the reduced basis set with the new basis vectors P|R(psi)>...
252244 // in which psi are the last eigenvectors
253245 // we define |R(psi)> as (H-ES)*|Psi>, E = <psi|H|psi>/<psi|S|psi>
254- std::vector<std:: complex < double >> vc_ev_vector (nbase);
246+ ModuleBase::ComplexMatrix vc_ev_vector (notconv, nbase);
255247 for (int m = 0 ; m < notconv; m++)
256248 {
257249 for (int i = 0 ; i < nbase; i++)
258250 {
259- vc_ev_vector[i] = vc (i, unconv[m]);
251+ vc_ev_vector (m, i) = vc (i, unconv[m]);
260252 }
261- int inc = 1 ;
262- char trans = ' N' ;
263- zgemv_ (&trans,
264- &npw,
265- &nbase,
266- &ModuleBase::ONE,
267- hp.c ,
268- &hp.nc ,
269- vc_ev_vector.data (),
270- &inc,
271- &ModuleBase::ZERO,
272- respsi,
273- &inc);
274-
253+ }
254+ ppsi = &basis (nbase, 0 );
255+ int inc = 1 ;
256+ char trans = ' N' ;
257+ char transb = ' N' ;
258+ zgemm_ (&trans,
259+ &transb,
260+ &npw, // m: row of A,C
261+ ¬conv, // n: col of B,C
262+ &nbase, // k: col of A, row of B
263+ &ModuleBase::ONE, // alpha
264+ hp.c , // A
265+ &hp.nc , // LDA: if(N) max(1,m) if(T) max(1,k)
266+ vc_ev_vector.c , // B
267+ &vc_ev_vector.nc , // LDB: if(N) max(1,k) if(T) max(1,n)
268+ &ModuleBase::ZERO, // belta
269+ ppsi, // C
270+ &basis.get_nbasis ()); // LDC: if(N) max(1, m)
271+ /* zgemv_(&trans,
272+ &npw,
273+ &nbase,
274+ &ModuleBase::ONE,
275+ hp.c,
276+ &hp.nc,
277+ vc_ev_vector.data(),
278+ &inc,
279+ &ModuleBase::ZERO,
280+ respsi,
281+ &inc);*/
282+ for (int m = 0 ; m < notconv; m++)
283+ {
275284 for (int i = 0 ; i < nbase; i++)
276285 {
277- vc_ev_vector[i] *= -1 * eigenvalue[unconv[m]];
286+ vc_ev_vector (m, i) *= -1 * eigenvalue[unconv[m]];
278287 }
279- zgemv_ (&trans,
288+ }
289+ zgemm_ (&trans,
290+ &transb,
291+ &npw, // m: row of A,C
292+ ¬conv, // n: col of B,C
293+ &nbase, // k: col of A, row of B
294+ &ModuleBase::ONE, // alpha
295+ sp.c , // A
296+ &sp.nc , // LDA: if(N) max(1,m) if(T) max(1,k)
297+ vc_ev_vector.c , // B
298+ &vc_ev_vector.nc , // LDB: if(N) max(1,k) if(T) max(1,n)
299+ &ModuleBase::ONE, // belta
300+ ppsi, // C
301+ &basis.get_nbasis ()); // LDC: if(N) max(1, m)
302+ /* zgemv_(&trans,
280303 &npw,
281304 &nbase,
282305 &ModuleBase::ONE,
@@ -286,7 +309,7 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
286309 &inc,
287310 &ModuleBase::ONE,
288311 respsi,
289- &inc);
312+ &inc);*/
290313
291314 /* ModuleBase::GlobalFunc::ZEROS(respsi, npw);
292315 for (int i = 0; i < nbase; i++)
@@ -301,11 +324,12 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
301324 }
302325 }*/
303326
304-
327+ for (int m = 0 ; m < notconv; m++)
328+ {
305329 ppsi = &basis (nbase + m, 0 );
306330 for (int ig = 0 ; ig < npw; ig++)
307331 {
308- ppsi[ig] = respsi[ig] / this ->precondition [ig];
332+ ppsi[ig] /= this ->precondition [ig];
309333 }
310334 }
311335
@@ -320,8 +344,8 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
320344 phm_in->sPsi (&basis (nbase + m, 0 ), &sp (nbase + m, 0 ), (size_t )npw);
321345 }
322346 // first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
323- char trans = ' C' ;
324- char transb = ' N' ;
347+ trans = ' C' ;
348+ transb = ' N' ;
325349 // calculate the square matrix for future lagranges
326350 zgemm_ (&trans,
327351 &transb,
0 commit comments