Skip to content

Commit 8034a77

Browse files
committed
Perf: add dynamic planning algorithm for SchmitOrth in davidson method
1 parent 7b7bd7b commit 8034a77

File tree

2 files changed

+178
-85
lines changed

2 files changed

+178
-85
lines changed

source/module_hsolver/diago_david.cpp

Lines changed: 169 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,41 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
6565

6666
ModuleBase::timer::tick("DiagoDavid", "first");
6767
// orthogonalise the initial trial psi(0~nband-1)
68+
//plan for SchmitOrth
69+
ModuleBase::ComplexMatrix lagrange_matrix(nband, nband);
70+
std::vector<int> pre_matrix_mm_m(nband, 0);
71+
std::vector<int> pre_matrix_mv_m(nband, 1);
72+
this->planSchmitOrth(nband, pre_matrix_mm_m.data(), pre_matrix_mv_m.data());
73+
for( int m = 0; m < nband; m++)
74+
{
75+
phm_in->sPsi(&psi(m, 0), &sp(m, 0), (size_t)dim);
76+
}
77+
//begin SchmitOrth
6878
for (int m = 0; m < nband; m++)
6979
{
7080
// psi_m = psi(m)
71-
ModuleBase::GlobalFunc::COPYARRAY(&psi(m, 0), psi_m.data(), dim);
81+
ModuleBase::GlobalFunc::COPYARRAY(&psi(m, 0), &basis(m, 0), dim);
7282
/*for (int ig = 0; ig < dim; ig++)
7383
{
7484
psi_m[ig] = psi(m, ig);
7585
}*/
7686

77-
phm_in->sPsi(psi_m.data(), spsi.data(), (size_t)dim);
78-
this->SchmitOrth(dim, nband, m, basis, psi_m.data(), spsi.data());
79-
phm_in->sPsi(psi_m.data(), spsi.data(), (size_t)dim);
87+
//phm_in->sPsi(psi_m.data(), spsi.data(), (size_t)dim);
88+
this->SchmitOrth(
89+
dim,
90+
nband,
91+
m,
92+
basis,
93+
sp,
94+
&lagrange_matrix(m, 0),
95+
pre_matrix_mm_m[m],
96+
pre_matrix_mv_m[m]
97+
);
98+
phm_in->sPsi(&basis(m, 0), &sp(m, 0), (size_t)dim);
8099

81100
// basis(m) = psi_m, hp(m) = H |psi_m>, sp(m) = S |psi_m>
82-
ModuleBase::GlobalFunc::COPYARRAY(psi_m.data(), &basis(m, 0), dim);
83-
ModuleBase::GlobalFunc::COPYARRAY(spsi.data(), &sp(m, 0), dim);
101+
//ModuleBase::GlobalFunc::COPYARRAY(psi_m.data(), &basis(m, 0), dim);
102+
//ModuleBase::GlobalFunc::COPYARRAY(spsi.data(), &sp(m, 0), dim);
84103
/*std::complex<double>* sp_p = &sp(m, 0);
85104
std::complex<double>* basis_p = &basis(m, 0);
86105
for (int ig = 0; ig < dim; ig++)
@@ -90,6 +109,7 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
90109
sp_p[ig] = spsi[ig];
91110
}*/
92111
}
112+
//end of SchmitOrth and calculate H|psi>
93113
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, 0, nband-1));
94114
auto hp_psi = std::get<0>(phm_in->ops->hPsi(dav_hpsi_in));
95115
ModuleBase::GlobalFunc::COPYARRAY(hp_psi->get_pointer(), &hp(0, 0), hp_psi->get_nbasis() * nband);
@@ -283,23 +303,57 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
283303

284304

285305
ppsi = &basis(nbase + m, 0);
286-
spsi = &sp(nbase + m, 0);
287306
for (int ig = 0; ig < npw; ig++)
288307
{
289308
ppsi[ig] = respsi[ig] / this->precondition[ig];
290309
}
310+
}
291311

292-
phm_in->sPsi(ppsi, spsi, (size_t)npw);
293-
this->SchmitOrth(npw, nbase + notconv, nbase + m, basis, ppsi, spsi);
312+
//orthogonalise the new bands from nbase to nbase + notconv - 1
313+
//plan for SchmitOrth
314+
ModuleBase::ComplexMatrix lagrange_matrix(notconv, nbase + notconv);
315+
std::vector<int> pre_matrix_mm_m(notconv, 0);
316+
std::vector<int> pre_matrix_mv_m(notconv, 1);
317+
this->planSchmitOrth(notconv, pre_matrix_mm_m.data(), pre_matrix_mv_m.data());
318+
for( int m = 0; m < notconv; m++)
319+
{
320+
phm_in->sPsi(&basis(nbase + m, 0), &sp(nbase + m, 0), (size_t)npw);
321+
}
322+
//first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
323+
char trans = 'C';
324+
char transb = 'N';
325+
//calculate the nbase x notconv overlap matrix for future lagranges
326+
zgemm_(&trans,
327+
&transb,
328+
&nbase, // m: row of A,C
329+
&notconv, // n: col of B,C
330+
&npw, // k: col of A, row of B
331+
&ModuleBase::ONE, // alpha
332+
&basis(0, 0), // A
333+
&basis.get_nbasis(), // LDA: if(N) max(1,m) if(T) max(1,k)
334+
&sp(nbase, 0), // B
335+
&sp.nc, // LDB: if(N) max(1,k) if(T) max(1,n)
336+
&ModuleBase::ZERO, // belta
337+
&lagrange_matrix(0, 0), // C
338+
&lagrange_matrix.nc); // LDC: if(N) max(1, m)
339+
340+
for (int m = 0; m < notconv; m++)
341+
{
342+
ppsi = &basis(nbase + m, 0);
343+
spsi = &sp(nbase + m, 0);
344+
345+
this->SchmitOrth(
346+
npw,
347+
nbase + notconv,
348+
nbase + m,
349+
basis,
350+
sp,
351+
&lagrange_matrix(m, 0),
352+
pre_matrix_mm_m[m],
353+
pre_matrix_mv_m[m]
354+
);
294355
phm_in->sPsi(ppsi, spsi, (size_t)npw);
295356

296-
297-
//for (int ig = 0; ig < npw; ig++)
298-
//{
299-
//basis(nbase + m, ig) = ppsi[ig];
300-
//hp(nbase + m, ig) = hpsi[ig];
301-
//sp(nbase + m, ig) = spsi[ig];
302-
//}
303357
}
304358
hp_info dav_hpsi_in(&basis, psi::Range(1, 0, nbase, nbase + notconv-1));
305359
auto hp_psi = std::get<0>(phm_in->ops->hPsi(dav_hpsi_in));
@@ -330,56 +384,6 @@ void DiagoDavid::cal_elem(const int &npw,
330384
// ModuleBase::GlobalFunc::ZEROS( hc.c+offset_h, notconv*hc.nr );
331385
// ModuleBase::GlobalFunc::ZEROS( sc.c+offset_s, notconv*sc.nr );
332386

333-
/*for (int i = nbase; i < nbase + notconv; i++)
334-
{
335-
char trans1 = 'C';
336-
char trans2 = 'N';
337-
const int nb_notc = i+1;
338-
const int one = 1;
339-
hc = transpose(hc, false);
340-
zgemm_(&trans1,
341-
&trans2,
342-
&one,
343-
&nb_notc,
344-
&npw,
345-
&ModuleBase::ONE,
346-
&basis(i , 0),
347-
&basis.get_nbasis(),
348-
hp.c,
349-
&hp.nc,
350-
&ModuleBase::ONE,
351-
&hc(0, i),
352-
&hc.nr);
353-
hc = transpose(hc, false);
354-
355-
sc = transpose(sc, false);
356-
zgemm_(&trans1,
357-
&trans2,
358-
&one,
359-
&nb_notc,
360-
&npw,
361-
&ModuleBase::ONE,
362-
&basis(i, 0),
363-
&basis.get_nbasis(),
364-
sp.c,
365-
&sp.nc,
366-
&ModuleBase::ONE,
367-
&sc(0, i),
368-
&sc.nr);
369-
sc = transpose(sc, false);
370-
for (int j = 0; j <= i; j++)
371-
{
372-
for (int ig = 0; ig < npw; ig++)
373-
{
374-
hc(i, j) += conj(basis(i, ig)) * hp(j, ig);
375-
sc(i, j) += conj(basis(i, ig)) * sp(j, ig);
376-
}
377-
// hc(j,i) = Diago_CG::ddot( npw, basis, j, hp, i );
378-
// sc(j,i) = Diago_CG::ddot( npw, basis, j, sp, i );
379-
}
380-
std::cout<<__FILE__<<__LINE__<<" "<<i<<" "<<hc(i,0)<<" "<<hc(i,1)<<std::endl;
381-
}*/
382-
383387
char trans1 = 'C';
384388
char trans2 = 'N';
385389
const int nb_notc = (nbase + notconv);
@@ -661,9 +665,11 @@ void DiagoDavid::cal_err(const int &npw,
661665
void DiagoDavid::SchmitOrth(const int &npw,
662666
const int n_band,
663667
const int m,
664-
const psi::Psi<std::complex<double>> &psi,
665-
std::complex<double> *psi_m,
666-
std::complex<double> *spsi)
668+
psi::Psi<std::complex<double>>& psi,
669+
const ModuleBase::ComplexMatrix& spsi,
670+
std::complex<double>* lagrange_m,
671+
const int mm_size,
672+
const int mv_size)
667673
{
668674
// if(test_david == 1) ModuleBase::TITLE("DiagoDavid","SchmitOrth");
669675
ModuleBase::timer::tick("DiagoDavid", "SchmitOrth");
@@ -678,22 +684,42 @@ void DiagoDavid::SchmitOrth(const int &npw,
678684
assert(m >= 0);
679685
assert(m < n_band);
680686

681-
std::complex<double> *lagrange = new std::complex<double>[m + 1];
682-
ModuleBase::GlobalFunc::ZEROS(lagrange, m + 1);
687+
std::complex<double>* psi_m = &psi(m, 0);
688+
689+
//std::complex<double> *lagrange = new std::complex<double>[m + 1];
690+
//ModuleBase::GlobalFunc::ZEROS(lagrange, m + 1);
683691

684692
int inc = 1;
685-
int mp = m;
686693
char trans = 'C';
694+
char transb = 'N';
695+
//calculate the square matrix for future lagranges
696+
if(mm_size != 0)
697+
{
698+
zgemm_(&trans,
699+
&transb,
700+
&mm_size, // m: row of A,C
701+
&mm_size, // n: col of B,C
702+
&npw, // k: col of A, row of B
703+
&ModuleBase::ONE, // alpha
704+
&psi(m-mv_size+1-mm_size, 0), // A
705+
&psi.get_nbasis(), // LDA: if(N) max(1,m) if(T) max(1,k)
706+
&spsi(m, 0), // B
707+
&spsi.nc, // LDB: if(N) max(1,k) if(T) max(1,n)
708+
&ModuleBase::ZERO, // belta
709+
&lagrange_m[m-mv_size+1-mm_size], // C
710+
&n_band); // LDC: if(N) max(1, m)
711+
}
712+
//calculate other lagranges for this band
687713
zgemv_(&trans,
688714
&npw,
689-
&mp,
715+
&mv_size,
690716
&ModuleBase::ONE,
691-
psi.get_pointer(),
717+
&psi(m-mv_size+1, 0),
692718
&psi.get_nbasis(),
693-
spsi,
719+
&spsi(m, 0),
694720
&inc,
695721
&ModuleBase::ZERO,
696-
lagrange,
722+
&lagrange_m[m-mv_size+1],
697723
&inc);
698724
/*for (int j = 0; j < m; j++)
699725
{
@@ -705,30 +731,30 @@ void DiagoDavid::SchmitOrth(const int &npw,
705731
}
706732
// lagrange[j] = Diago_CG::ddot( npw, psi, j, spsi );
707733
}*/
708-
zdotc_(&lagrange[m], &npw, psi_m, &inc, spsi, &inc);
734+
//zdotc_(&lagrange[m], &npw, psi_m, &inc, spsi, &inc);
709735
/*for (int ig = 0; ig < npw; ig++)
710736
{
711737
lagrange[m] += conj(psi_m[ig]) * spsi[ig];
712738
}*/
713739
// lagrange[m] = Diago_CG::ddot( npw, psi_m, spsi );
714740

715-
Parallel_Reduce::reduce_complex_double_pool(lagrange, m + 1);
741+
Parallel_Reduce::reduce_complex_double_pool(lagrange_m, m + 1);
716742

717743
// out.printr1_d("lagrange", lagrange, m+1 );
718744

719-
double psi_norm = lagrange[m].real();
745+
double psi_norm = lagrange_m[m].real();
720746
assert(psi_norm > 0.0);
721747
// std::cout << "m = " << m << std::endl;
722748

723749
for (int j = 0; j < m; j++)
724750
{
725-
const std::complex<double> alpha = std::complex<double>(-1, 0) * lagrange[j];
751+
const std::complex<double> alpha = std::complex<double>(-1, 0) * lagrange_m[j];
726752
zaxpy_(&npw, &alpha, &psi(j,0), &inc, psi_m, &inc);
727753
/*for (int ig = 0; ig < npw; ig++)
728754
{
729755
psi_m[ig] -= lagrange[j] * psi(j, ig);
730756
}*/
731-
psi_norm -= (conj(lagrange[j]) * lagrange[j]).real();
757+
psi_norm -= (conj(lagrange_m[j]) * lagrange_m[j]).real();
732758
}
733759

734760
assert(psi_norm > 0.0);
@@ -750,11 +776,72 @@ void DiagoDavid::SchmitOrth(const int &npw,
750776
}
751777
}
752778

753-
delete[] lagrange;
779+
//delete[] lagrange;
754780
ModuleBase::timer::tick("DiagoDavid", "SchmitOrth");
755781
return;
756782
}
757783

784+
// Build the per-band plan consumed by SchmitOrth.
//
// nband            number of bands to plan for; nothing is written when nband <= 0.
// pre_matrix_mm_m  output array of nband ints, zeroed first. Entry m is later passed
//                  to SchmitOrth as mm_size (the dimension of the square zgemm block).
//                  It stays 0 except at the "divide points" chosen in the loop below,
//                  where it is set to the current matrix_size.
// pre_matrix_mv_m  output array of nband ints, zeroed first. Entry m is later passed
//                  to SchmitOrth as mv_size (the number of columns in the zgemv call).
//                  It is 1 for band 0, 1 or 2 at divide points, and for every entry
//                  still 0 after the loop it becomes the previous entry + 1.
//
// Divide points are generated by repeatedly halving: matrix_size starts at nband / 2
// and is recomputed as res_nband / 2 each round, while the number of divide points
// doubles after the first round. The loop stops once matrix_size <= 1.
// NOTE(review): a divide point gets mv_size 2 when res_nband != matrix_size, which
// presumably accounts for the odd leftover band in that round - confirm.
void DiagoDavid::planSchmitOrth(
785+
const int nband,
786+
int* pre_matrix_mm_m,
787+
int* pre_matrix_mv_m
788+
)
789+
{
790+
if(nband<=0)return;
791+
ModuleBase::GlobalFunc::ZEROS(pre_matrix_mm_m, nband);
792+
ModuleBase::GlobalFunc::ZEROS(pre_matrix_mv_m, nband);
793+
int last_matrix_size = nband;
794+
int matrix_size = int(nband / 2);
795+
int divide_times = 0;
796+
std::vector<int> divide_points(nband);
797+
int res_nband = nband - matrix_size;
798+
while(matrix_size>1)
799+
{
800+
int index = nband - matrix_size;
801+
if(divide_times == 0)
802+
{
803+
divide_points[0] = index;
804+
pre_matrix_mm_m[index] = matrix_size;
805+
if(res_nband == matrix_size) pre_matrix_mv_m[index] = 1;
806+
else pre_matrix_mv_m[index] = 2;
807+
divide_times = 1;
808+
}
809+
else
810+
{
811+
for(int i=divide_times-1; i>=0; i--)
812+
{
813+
divide_points[i*2] = divide_points[i] - matrix_size;
814+
divide_points[i*2+1] = divide_points[i*2] + last_matrix_size;
815+
pre_matrix_mm_m[ divide_points[i*2] ] = matrix_size;
816+
pre_matrix_mm_m[ divide_points[i*2+1]] = matrix_size;
817+
if(res_nband == matrix_size)
818+
{
819+
pre_matrix_mv_m[divide_points[i*2]] = 1;
820+
pre_matrix_mv_m[divide_points[i*2+1]] = 1;
821+
}
822+
else
823+
{
824+
pre_matrix_mv_m[divide_points[i*2]] = 2;
825+
pre_matrix_mv_m[divide_points[i*2+1]] = 2;
826+
}
827+
}
828+
divide_times *= 2;
829+
}
830+
last_matrix_size = matrix_size;
831+
matrix_size = int(res_nband / 2);
832+
res_nband -= matrix_size;
833+
}
834+
//fill the pre_matrix_mv_m array
835+
pre_matrix_mv_m[0] = 1;
836+
for(int m = 1; m < nband; m++)
837+
{
838+
if(pre_matrix_mv_m[m] == 0)
839+
{
840+
pre_matrix_mv_m[m] = pre_matrix_mv_m[m-1]+1;
841+
}
842+
}
843+
}
844+
758845
void DiagoDavid::diag(hamilt::Hamilt *phm_in, psi::Psi<std::complex<double>> &psi, double *eigenvalue_in)
759846
{
760847
/// record the times of trying iterative diagonalization

source/module_hsolver/diago_david.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,15 @@ class DiagoDavid : public DiagH
7878
void SchmitOrth(const int& npw,
7979
const int n_band,
8080
const int m,
81-
const psi::Psi<std::complex<double>>& psi,
82-
std::complex<double>* psi_m,
83-
std::complex<double>* spsi);
81+
psi::Psi<std::complex<double>>& psi,
82+
const ModuleBase::ComplexMatrix& spsi,
83+
std::complex<double>* lagrange_m,
84+
const int mm_size,
85+
const int mv_size);
86+
void planSchmitOrth(
87+
const int nband,
88+
int* pre_matrix_mm_m,
89+
int* pre_matrix_mv_m);
8490

8591
void diag_zhegvx(const int& n,
8692
const int& m,

0 commit comments

Comments
 (0)