Skip to content

Commit eba0ac6

Browse files
authored
Merge pull request #1118 from deepmodeling/HSolver
Perf: update cal_grad() in davidson method
2 parents 6d4cb63 + 0522ce5 commit eba0ac6

File tree

2 files changed

+58
-35
lines changed

2 files changed

+58
-35
lines changed

source/module_hsolver/diago_david.cpp

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,6 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
4949
ModuleBase::ComplexMatrix vc(nbase_x, nbase_x); // Eigenvectors of hc
5050
std::vector<double> eigenvalue(nbase_x); // the lowest N eigenvalues of hc
5151

52-
std::vector<std::complex<double>> psi_m(dim);
53-
std::vector<std::complex<double>> hpsi(dim);
54-
std::vector<std::complex<double>> spsi(dim);
55-
std::vector<std::complex<double>> ppsi(dim);
56-
std::vector<std::complex<double>> respsi(dim);
57-
5852
std::vector<bool> convflag(nband, false); // convflag[m] = true if the m th band is convergent
5953
std::vector<int> unconv(nband); // unconv[m] store the number of the m th unconvergent band
6054

@@ -142,8 +136,7 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
142136
sp,
143137
vc,
144138
unconv.data(),
145-
eigenvalue.data(),
146-
respsi.data());
139+
eigenvalue.data());
147140

148141
this->cal_elem(dim, nbase, this->notconv, basis, hp, sp, hc, sc);
149142

@@ -235,8 +228,7 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
235228
ModuleBase::ComplexMatrix &sp,
236229
const ModuleBase::ComplexMatrix &vc,
237230
const int *unconv,
238-
const double *eigenvalue,
239-
std::complex<double> *respsi)
231+
const double *eigenvalue)
240232
{
241233
if (test_david == 1)
242234
ModuleBase::TITLE("DiagoDavid", "cal_grad");
@@ -251,32 +243,63 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
251243
// expand the reduced basis set with the new basis vectors P|R(psi)>...
252244
// in which psi are the last eigenvectors
253245
// we define |R(psi)> as (H-ES)*|Psi>, E = <psi|H|psi>/<psi|S|psi>
254-
std::vector<std::complex<double>> vc_ev_vector(nbase);
246+
ModuleBase::ComplexMatrix vc_ev_vector(notconv, nbase);
255247
for (int m = 0; m < notconv; m++)
256248
{
257249
for(int i = 0; i < nbase; i++)
258250
{
259-
vc_ev_vector[i] = vc(i, unconv[m]);
251+
vc_ev_vector(m, i) = vc(i, unconv[m]);
260252
}
261-
int inc = 1;
262-
char trans = 'N';
263-
zgemv_(&trans,
264-
&npw,
265-
&nbase,
266-
&ModuleBase::ONE,
267-
hp.c,
268-
&hp.nc,
269-
vc_ev_vector.data(),
270-
&inc,
271-
&ModuleBase::ZERO,
272-
respsi,
273-
&inc);
274-
253+
}
254+
ppsi = &basis(nbase, 0);
255+
int inc = 1;
256+
char trans = 'N';
257+
char transb = 'N';
258+
zgemm_(&trans,
259+
&transb,
260+
&npw, // m: row of A,C
261+
&notconv, // n: col of B,C
262+
&nbase, // k: col of A, row of B
263+
&ModuleBase::ONE, // alpha
264+
hp.c, // A
265+
&hp.nc, // LDA: if(N) max(1,m) if(T) max(1,k)
266+
vc_ev_vector.c, // B
267+
&vc_ev_vector.nc, // LDB: if(N) max(1,k) if(T) max(1,n)
268+
&ModuleBase::ZERO, // belta
269+
ppsi, // C
270+
&basis.get_nbasis()); // LDC: if(N) max(1, m)
271+
/*zgemv_(&trans,
272+
&npw,
273+
&nbase,
274+
&ModuleBase::ONE,
275+
hp.c,
276+
&hp.nc,
277+
vc_ev_vector.data(),
278+
&inc,
279+
&ModuleBase::ZERO,
280+
respsi,
281+
&inc);*/
282+
for (int m = 0; m < notconv; m++)
283+
{
275284
for(int i = 0; i < nbase; i++)
276285
{
277-
vc_ev_vector[i] *= -1 * eigenvalue[unconv[m]];
286+
vc_ev_vector(m, i) *= -1 * eigenvalue[unconv[m]];
278287
}
279-
zgemv_(&trans,
288+
}
289+
zgemm_(&trans,
290+
&transb,
291+
&npw, // m: row of A,C
292+
&notconv, // n: col of B,C
293+
&nbase, // k: col of A, row of B
294+
&ModuleBase::ONE, // alpha
295+
sp.c, // A
296+
&sp.nc, // LDA: if(N) max(1,m) if(T) max(1,k)
297+
vc_ev_vector.c, // B
298+
&vc_ev_vector.nc, // LDB: if(N) max(1,k) if(T) max(1,n)
299+
&ModuleBase::ONE, // belta
300+
ppsi, // C
301+
&basis.get_nbasis()); // LDC: if(N) max(1, m)
302+
/*zgemv_(&trans,
280303
&npw,
281304
&nbase,
282305
&ModuleBase::ONE,
@@ -286,7 +309,7 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
286309
&inc,
287310
&ModuleBase::ONE,
288311
respsi,
289-
&inc);
312+
&inc);*/
290313

291314
/*ModuleBase::GlobalFunc::ZEROS(respsi, npw);
292315
for (int i = 0; i < nbase; i++)
@@ -301,11 +324,12 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
301324
}
302325
}*/
303326

304-
327+
for (int m = 0; m < notconv; m++)
328+
{
305329
ppsi = &basis(nbase + m, 0);
306330
for (int ig = 0; ig < npw; ig++)
307331
{
308-
ppsi[ig] = respsi[ig] / this->precondition[ig];
332+
ppsi[ig] /= this->precondition[ig];
309333
}
310334
}
311335

@@ -320,8 +344,8 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
320344
phm_in->sPsi(&basis(nbase + m, 0), &sp(nbase + m, 0), (size_t)npw);
321345
}
322346
//first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
323-
char trans = 'C';
324-
char transb = 'N';
347+
trans = 'C';
348+
transb = 'N';
325349
//calculate the square matrix for future lagranges
326350
zgemm_(&trans,
327351
&transb,

source/module_hsolver/diago_david.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ class DiagoDavid : public DiagH
4242
ModuleBase::ComplexMatrix& sp,
4343
const ModuleBase::ComplexMatrix& vc,
4444
const int* unconv,
45-
const double* en,
46-
std::complex<double>* respsi);
45+
const double* en);
4746

4847
void cal_elem(const int& npw,
4948
int& nbase,

0 commit comments

Comments
 (0)