@@ -65,22 +65,41 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
6565
6666 ModuleBase::timer::tick (" DiagoDavid" , " first" );
6767 // orthogonalise the initial trial psi(0~nband-1)
68+ // plan for SchmitOrth
69+ ModuleBase::ComplexMatrix lagrange_matrix (nband, nband);
70+ std::vector<int > pre_matrix_mm_m (nband, 0 );
71+ std::vector<int > pre_matrix_mv_m (nband, 1 );
72+ this ->planSchmitOrth (nband, pre_matrix_mm_m.data (), pre_matrix_mv_m.data ());
73+ for ( int m = 0 ; m < nband; m++)
74+ {
75+ phm_in->sPsi (&psi (m, 0 ), &sp (m, 0 ), (size_t )dim);
76+ }
77+ // begin SchmitOrth
6878 for (int m = 0 ; m < nband; m++)
6979 {
7080 // psi_m = psi(m)
71- ModuleBase::GlobalFunc::COPYARRAY (&psi (m, 0 ), psi_m. data ( ), dim);
81+ ModuleBase::GlobalFunc::COPYARRAY (&psi (m, 0 ), & basis (m, 0 ), dim);
7282 /* for (int ig = 0; ig < dim; ig++)
7383 {
7484 psi_m[ig] = psi(m, ig);
7585 }*/
7686
77- phm_in->sPsi (psi_m.data (), spsi.data (), (size_t )dim);
78- this ->SchmitOrth (dim, nband, m, basis, psi_m.data (), spsi.data ());
79- phm_in->sPsi (psi_m.data (), spsi.data (), (size_t )dim);
87+ // phm_in->sPsi(psi_m.data(), spsi.data(), (size_t)dim);
88+ this ->SchmitOrth (
89+ dim,
90+ nband,
91+ m,
92+ basis,
93+ sp,
94+ &lagrange_matrix (m, 0 ),
95+ pre_matrix_mm_m[m],
96+ pre_matrix_mv_m[m]
97+ );
98+ phm_in->sPsi (&basis (m, 0 ), &sp (m, 0 ), (size_t )dim);
8099
81100 // basis(m) = psi_m, hp(m) = H |psi_m>, sp(m) = S |psi_m>
82- ModuleBase::GlobalFunc::COPYARRAY (psi_m.data (), &basis (m, 0 ), dim);
83- ModuleBase::GlobalFunc::COPYARRAY (spsi.data (), &sp (m, 0 ), dim);
101+ // ModuleBase::GlobalFunc::COPYARRAY(psi_m.data(), &basis(m, 0), dim);
102+ // ModuleBase::GlobalFunc::COPYARRAY(spsi.data(), &sp(m, 0), dim);
84103 /* std::complex<double>* sp_p = &sp(m, 0);
85104 std::complex<double>* basis_p = &basis(m, 0);
86105 for (int ig = 0; ig < dim; ig++)
@@ -90,6 +109,7 @@ void DiagoDavid::diag_mock(hamilt::Hamilt* phm_in, psi::Psi<std::complex<double>
90109 sp_p[ig] = spsi[ig];
91110 }*/
92111 }
112+ // end of SchmitOrth and calculate H|psi>
93113 hp_info dav_hpsi_in (&basis, psi::Range (1 , 0 , 0 , nband-1 ));
94114 auto hp_psi = std::get<0 >(phm_in->ops ->hPsi (dav_hpsi_in));
95115 ModuleBase::GlobalFunc::COPYARRAY (hp_psi->get_pointer (), &hp (0 , 0 ), hp_psi->get_nbasis () * nband);
@@ -283,23 +303,57 @@ void DiagoDavid::cal_grad(hamilt::Hamilt* phm_in,
283303
284304
285305 ppsi = &basis (nbase + m, 0 );
286- spsi = &sp (nbase + m, 0 );
287306 for (int ig = 0 ; ig < npw; ig++)
288307 {
289308 ppsi[ig] = respsi[ig] / this ->precondition [ig];
290309 }
310+ }
291311
292- phm_in->sPsi (ppsi, spsi, (size_t )npw);
293- this ->SchmitOrth (npw, nbase + notconv, nbase + m, basis, ppsi, spsi);
312+ // there is a nbase to nbase + notconv band orthogonalise
313+ // plan for SchmitOrth
314+ ModuleBase::ComplexMatrix lagrange_matrix (notconv, nbase + notconv);
315+ std::vector<int > pre_matrix_mm_m (notconv, 0 );
316+ std::vector<int > pre_matrix_mv_m (notconv, 1 );
317+ this ->planSchmitOrth (notconv, pre_matrix_mm_m.data (), pre_matrix_mv_m.data ());
318+ for ( int m = 0 ; m < notconv; m++)
319+ {
320+ phm_in->sPsi (&basis (nbase + m, 0 ), &sp (nbase + m, 0 ), (size_t )npw);
321+ }
322+ // first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix
323+ char trans = ' C' ;
324+ char transb = ' N' ;
325+ // calculate the square matrix for future lagranges
326+ zgemm_ (&trans,
327+ &transb,
328+ &nbase, // m: row of A,C
329+ ¬conv, // n: col of B,C
330+ &npw, // k: col of A, row of B
331+ &ModuleBase::ONE, // alpha
332+ &basis (0 , 0 ), // A
333+ &basis.get_nbasis (), // LDA: if(N) max(1,m) if(T) max(1,k)
334+ &sp (nbase, 0 ), // B
335+ &sp.nc , // LDB: if(N) max(1,k) if(T) max(1,n)
336+ &ModuleBase::ZERO, // belta
337+ &lagrange_matrix (0 , 0 ), // C
338+ &lagrange_matrix.nc ); // LDC: if(N) max(1, m)
339+
340+ for (int m = 0 ; m < notconv; m++)
341+ {
342+ ppsi = &basis (nbase + m, 0 );
343+ spsi = &sp (nbase + m, 0 );
344+
345+ this ->SchmitOrth (
346+ npw,
347+ nbase + notconv,
348+ nbase + m,
349+ basis,
350+ sp,
351+ &lagrange_matrix (m, 0 ),
352+ pre_matrix_mm_m[m],
353+ pre_matrix_mv_m[m]
354+ );
294355 phm_in->sPsi (ppsi, spsi, (size_t )npw);
295356
296-
297- // for (int ig = 0; ig < npw; ig++)
298- // {
299- // basis(nbase + m, ig) = ppsi[ig];
300- // hp(nbase + m, ig) = hpsi[ig];
301- // sp(nbase + m, ig) = spsi[ig];
302- // }
303357 }
304358 hp_info dav_hpsi_in (&basis, psi::Range (1 , 0 , nbase, nbase + notconv-1 ));
305359 auto hp_psi = std::get<0 >(phm_in->ops ->hPsi (dav_hpsi_in));
@@ -330,56 +384,6 @@ void DiagoDavid::cal_elem(const int &npw,
330384 // ModuleBase::GlobalFunc::ZEROS( hc.c+offset_h, notconv*hc.nr );
331385 // ModuleBase::GlobalFunc::ZEROS( sc.c+offset_s, notconv*sc.nr );
332386
333- /* for (int i = nbase; i < nbase + notconv; i++)
334- {
335- char trans1 = 'C';
336- char trans2 = 'N';
337- const int nb_notc = i+1;
338- const int one = 1;
339- hc = transpose(hc, false);
340- zgemm_(&trans1,
341- &trans2,
342- &one,
343- &nb_notc,
344- &npw,
345- &ModuleBase::ONE,
346- &basis(i , 0),
347- &basis.get_nbasis(),
348- hp.c,
349- &hp.nc,
350- &ModuleBase::ONE,
351- &hc(0, i),
352- &hc.nr);
353- hc = transpose(hc, false);
354-
355- sc = transpose(sc, false);
356- zgemm_(&trans1,
357- &trans2,
358- &one,
359- &nb_notc,
360- &npw,
361- &ModuleBase::ONE,
362- &basis(i, 0),
363- &basis.get_nbasis(),
364- sp.c,
365- &sp.nc,
366- &ModuleBase::ONE,
367- &sc(0, i),
368- &sc.nr);
369- sc = transpose(sc, false);
370- for (int j = 0; j <= i; j++)
371- {
372- for (int ig = 0; ig < npw; ig++)
373- {
374- hc(i, j) += conj(basis(i, ig)) * hp(j, ig);
375- sc(i, j) += conj(basis(i, ig)) * sp(j, ig);
376- }
377- // hc(j,i) = Diago_CG::ddot( npw, basis, j, hp, i );
378- // sc(j,i) = Diago_CG::ddot( npw, basis, j, sp, i );
379- }
380- std::cout<<__FILE__<<__LINE__<<" "<<i<<" "<<hc(i,0)<<" "<<hc(i,1)<<std::endl;
381- }*/
382-
383387 char trans1 = ' C' ;
384388 char trans2 = ' N' ;
385389 const int nb_notc = (nbase + notconv);
@@ -661,9 +665,11 @@ void DiagoDavid::cal_err(const int &npw,
661665void DiagoDavid::SchmitOrth (const int &npw,
662666 const int n_band,
663667 const int m,
664- const psi::Psi<std::complex <double >> &psi,
665- std::complex <double > *psi_m,
666- std::complex <double > *spsi)
668+ psi::Psi<std::complex <double >>& psi,
669+ const ModuleBase::ComplexMatrix& spsi,
670+ std::complex <double >* lagrange_m,
671+ const int mm_size,
672+ const int mv_size)
667673{
668674 // if(test_david == 1) ModuleBase::TITLE("DiagoDavid","SchmitOrth");
669675 ModuleBase::timer::tick (" DiagoDavid" , " SchmitOrth" );
@@ -678,22 +684,42 @@ void DiagoDavid::SchmitOrth(const int &npw,
678684 assert (m >= 0 );
679685 assert (m < n_band);
680686
681- std::complex <double > *lagrange = new std::complex <double >[m + 1 ];
682- ModuleBase::GlobalFunc::ZEROS (lagrange, m + 1 );
687+ std::complex <double >* psi_m = &psi (m, 0 );
688+
689+ // std::complex<double> *lagrange = new std::complex<double>[m + 1];
690+ // ModuleBase::GlobalFunc::ZEROS(lagrange, m + 1);
683691
684692 int inc = 1 ;
685- int mp = m;
686693 char trans = ' C' ;
694+ char transb = ' N' ;
695+ // calculate the square matrix for future lagranges
696+ if (mm_size != 0 )
697+ {
698+ zgemm_ (&trans,
699+ &transb,
700+ &mm_size, // m: row of A,C
701+ &mm_size, // n: col of B,C
702+ &npw, // k: col of A, row of B
703+ &ModuleBase::ONE, // alpha
704+ &psi (m-mv_size+1 -mm_size, 0 ), // A
705+ &psi.get_nbasis (), // LDA: if(N) max(1,m) if(T) max(1,k)
706+ &spsi (m, 0 ), // B
707+ &spsi.nc , // LDB: if(N) max(1,k) if(T) max(1,n)
708+ &ModuleBase::ZERO, // belta
709+ &lagrange_m[m-mv_size+1 -mm_size], // C
710+ &n_band); // LDC: if(N) max(1, m)
711+ }
712+ // calculate other lagranges for this band
687713 zgemv_ (&trans,
688714 &npw,
689- &mp ,
715+ &mv_size ,
690716 &ModuleBase::ONE,
691- psi. get_pointer ( ),
717+ & psi (m-mv_size+ 1 , 0 ),
692718 &psi.get_nbasis (),
693- spsi,
719+ & spsi (m, 0 ) ,
694720 &inc,
695721 &ModuleBase::ZERO,
696- lagrange ,
722+ &lagrange_m[m-mv_size+ 1 ] ,
697723 &inc);
698724 /* for (int j = 0; j < m; j++)
699725 {
@@ -705,30 +731,30 @@ void DiagoDavid::SchmitOrth(const int &npw,
705731 }
706732 // lagrange[j] = Diago_CG::ddot( npw, psi, j, spsi );
707733 }*/
708- zdotc_ (&lagrange[m], &npw, psi_m, &inc, spsi, &inc);
734+ // zdotc_(&lagrange[m], &npw, psi_m, &inc, spsi, &inc);
709735 /* for (int ig = 0; ig < npw; ig++)
710736 {
711737 lagrange[m] += conj(psi_m[ig]) * spsi[ig];
712738 }*/
713739 // lagrange[m] = Diago_CG::ddot( npw, psi_m, spsi );
714740
715- Parallel_Reduce::reduce_complex_double_pool (lagrange , m + 1 );
741+ Parallel_Reduce::reduce_complex_double_pool (lagrange_m , m + 1 );
716742
717743 // out.printr1_d("lagrange", lagrange, m+1 );
718744
719- double psi_norm = lagrange [m].real ();
745+ double psi_norm = lagrange_m [m].real ();
720746 assert (psi_norm > 0.0 );
721747 // std::cout << "m = " << m << std::endl;
722748
723749 for (int j = 0 ; j < m; j++)
724750 {
725- const std::complex <double > alpha = std::complex <double >(-1 , 0 ) * lagrange [j];
751+ const std::complex <double > alpha = std::complex <double >(-1 , 0 ) * lagrange_m [j];
726752 zaxpy_ (&npw, &alpha, &psi (j,0 ), &inc, psi_m, &inc);
727753 /* for (int ig = 0; ig < npw; ig++)
728754 {
729755 psi_m[ig] -= lagrange[j] * psi(j, ig);
730756 }*/
731- psi_norm -= (conj (lagrange [j]) * lagrange [j]).real ();
757+ psi_norm -= (conj (lagrange_m [j]) * lagrange_m [j]).real ();
732758 }
733759
734760 assert (psi_norm > 0.0 );
@@ -750,11 +776,72 @@ void DiagoDavid::SchmitOrth(const int &npw,
750776 }
751777 }
752778
753- delete[] lagrange;
779+ // delete[] lagrange;
754780 ModuleBase::timer::tick (" DiagoDavid" , " SchmitOrth" );
755781 return ;
756782}
757783
784+ void DiagoDavid::planSchmitOrth (
785+ const int nband,
786+ int * pre_matrix_mm_m,
787+ int * pre_matrix_mv_m
788+ )
789+ {
790+ if (nband<=0 )return ;
791+ ModuleBase::GlobalFunc::ZEROS (pre_matrix_mm_m, nband);
792+ ModuleBase::GlobalFunc::ZEROS (pre_matrix_mv_m, nband);
793+ int last_matrix_size = nband;
794+ int matrix_size = int (nband / 2 );
795+ int divide_times = 0 ;
796+ std::vector<int > divide_points (nband);
797+ int res_nband = nband - matrix_size;
798+ while (matrix_size>1 )
799+ {
800+ int index = nband - matrix_size;
801+ if (divide_times == 0 )
802+ {
803+ divide_points[0 ] = index;
804+ pre_matrix_mm_m[index] = matrix_size;
805+ if (res_nband == matrix_size) pre_matrix_mv_m[index] = 1 ;
806+ else pre_matrix_mv_m[index] = 2 ;
807+ divide_times = 1 ;
808+ }
809+ else
810+ {
811+ for (int i=divide_times-1 ; i>=0 ; i--)
812+ {
813+ divide_points[i*2 ] = divide_points[i] - matrix_size;
814+ divide_points[i*2 +1 ] = divide_points[i*2 ] + last_matrix_size;
815+ pre_matrix_mm_m[ divide_points[i*2 ] ] = matrix_size;
816+ pre_matrix_mm_m[ divide_points[i*2 +1 ]] = matrix_size;
817+ if (res_nband == matrix_size)
818+ {
819+ pre_matrix_mv_m[divide_points[i*2 ]] = 1 ;
820+ pre_matrix_mv_m[divide_points[i*2 +1 ]] = 1 ;
821+ }
822+ else
823+ {
824+ pre_matrix_mv_m[divide_points[i*2 ]] = 2 ;
825+ pre_matrix_mv_m[divide_points[i*2 +1 ]] = 2 ;
826+ }
827+ }
828+ divide_times *= 2 ;
829+ }
830+ last_matrix_size = matrix_size;
831+ matrix_size = int (res_nband / 2 );
832+ res_nband -= matrix_size;
833+ }
834+ // fill the pre_matrix_mv_m array
835+ pre_matrix_mv_m[0 ] = 1 ;
836+ for (int m = 1 ; m < nband; m++)
837+ {
838+ if (pre_matrix_mv_m[m] == 0 )
839+ {
840+ pre_matrix_mv_m[m] = pre_matrix_mv_m[m-1 ]+1 ;
841+ }
842+ }
843+ }
844+
758845void DiagoDavid::diag (hamilt::Hamilt *phm_in, psi::Psi<std::complex <double >> &psi, double *eigenvalue_in)
759846{
760847 // / record the times of trying iterative diagonalization
0 commit comments