@@ -54,6 +54,10 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
54
54
resmem_complex_op ()(this ->hphi , this ->nbase_x * this ->dim , " DAV::hphi" );
55
55
setmem_complex_op ()(this ->hphi , 0 , this ->nbase_x * this ->dim );
56
56
57
+ // the product of S and psi in the reduced psi set
58
+ resmem_complex_op ()(this ->sphi , this ->nbase_x * this ->dim , " DAV::sphi" );
59
+ setmem_complex_op ()(this ->sphi , 0 , this ->nbase_x * this ->dim );
60
+
57
61
// Hamiltonian on the reduced psi set
58
62
resmem_complex_op ()(this ->hcc , this ->nbase_x * this ->nbase_x , " DAV::hcc" );
59
63
setmem_complex_op ()(this ->hcc , 0 , this ->nbase_x * this ->nbase_x );
@@ -96,6 +100,7 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
96
100
97
101
template <typename T, typename Device>
98
102
int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
103
+ const HPsiFunc& spsi_func,
99
104
T* psi_in,
100
105
const int psi_in_dmax,
101
106
Real* eigenvalue_in_hsolver,
@@ -134,7 +139,11 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
134
139
// hphi[:, 0:nbase_x] = H * psi_in_iter[:, 0:nbase_x]
135
140
hpsi_func (this ->psi_in_iter , this ->hphi , this ->dim , this ->notconv );
136
141
137
- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->hphi , this ->hcc , this ->scc );
142
+ // compute s*psi_in_iter
143
+ // sphi[:, 0:nbase_x] = S * psi_in_iter[:, 0:nbase_x]
144
+ spsi_func (this ->psi_in_iter , this ->sphi , this ->dim , this ->notconv );
145
+
146
+ this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->sphi , this ->hphi , this ->hcc , this ->scc );
138
147
139
148
this ->diag_zhegvx (nbase, this ->notconv , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
140
149
@@ -152,16 +161,25 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
152
161
dav_iter++;
153
162
154
163
this ->cal_grad (hpsi_func,
164
+ spsi_func,
155
165
this ->dim ,
156
166
nbase,
157
167
this ->notconv ,
158
168
this ->psi_in_iter ,
159
169
this ->hphi ,
170
+ this ->sphi ,
160
171
this ->vcc ,
161
172
unconv.data (),
162
173
&eigenvalue_iter);
163
174
164
- this ->cal_elem (this ->dim , nbase, this ->notconv , this ->psi_in_iter , this ->hphi , this ->hcc , this ->scc );
175
+ this ->cal_elem (this ->dim ,
176
+ nbase,
177
+ this ->notconv ,
178
+ this ->psi_in_iter ,
179
+ this ->sphi ,
180
+ this ->hphi ,
181
+ this ->hcc ,
182
+ this ->scc );
165
183
166
184
this ->diag_zhegvx (nbase, this ->n_band , this ->hcc , this ->scc , this ->nbase_x , &eigenvalue_iter, this ->vcc );
167
185
@@ -238,6 +256,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
238
256
eigenvalue_in_hsolver,
239
257
this ->psi_in_iter ,
240
258
this ->hphi ,
259
+ this ->sphi ,
241
260
this ->hcc ,
242
261
this ->scc ,
243
262
this ->vcc );
@@ -255,11 +274,13 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
255
274
256
275
template <typename T, typename Device>
257
276
void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
277
+ const HPsiFunc& spsi_func,
258
278
const int & dim,
259
279
const int & nbase,
260
280
const int & notconv,
261
281
T* psi_iter,
262
282
T* hphi,
283
+ T* spsi,
263
284
T* vcc,
264
285
const int * unconv,
265
286
std::vector<Real>* eigenvalue_iter)
@@ -331,7 +352,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
331
352
notconv,
332
353
nbase,
333
354
this ->one ,
334
- psi_iter ,
355
+ sphi ,
335
356
this ->dim ,
336
357
vcc,
337
358
this ->nbase_x ,
@@ -396,6 +417,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
396
417
// update hpsi[:, nbase:nbase+notconv]
397
418
// hpsi[:, nbase:nbase+notconv] = H * psi_iter[:, nbase:nbase+notconv]
398
419
hpsi_func (psi_iter + nbase * dim, hphi + nbase * this ->dim , this ->dim , notconv);
420
+ spsi_func (psi_iter + nbase * dim, sphi + nbase * this ->dim , this ->dim , notconv);
399
421
400
422
ModuleBase::timer::tick (" Diago_DavSubspace" , " cal_grad" );
401
423
return ;
@@ -406,6 +428,7 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
406
428
int & nbase,
407
429
const int & notconv,
408
430
const T* psi_iter,
431
+ const T* spsi,
409
432
const T* hphi,
410
433
T* hcc,
411
434
T* scc)
@@ -416,39 +439,39 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
416
439
ModuleBase::gemm_op_mt<T, Device>()
417
440
#else
418
441
ModuleBase::gemm_op<T, Device>()
419
- #endif
420
- (' C' ,
421
- ' N' ,
422
- nbase + notconv,
423
- notconv,
424
- this ->dim ,
425
- this ->one ,
426
- psi_iter,
427
- this ->dim ,
428
- &hphi[nbase * this ->dim ],
429
- this ->dim ,
430
- this ->zero ,
431
- &hcc[nbase * this ->nbase_x ],
432
- this ->nbase_x );
442
+ #endif
443
+ (' C' ,
444
+ ' N' ,
445
+ nbase + notconv,
446
+ notconv,
447
+ this ->dim ,
448
+ this ->one ,
449
+ psi_iter,
450
+ this ->dim ,
451
+ &hphi[nbase * this ->dim ],
452
+ this ->dim ,
453
+ this ->zero ,
454
+ &hcc[nbase * this ->nbase_x ],
455
+ this ->nbase_x );
433
456
434
457
#ifdef __DSP
435
458
ModuleBase::gemm_op_mt<T, Device>()
436
459
#else
437
460
ModuleBase::gemm_op<T, Device>()
438
461
#endif
439
- (' C' ,
440
- ' N' ,
441
- nbase + notconv,
442
- notconv,
443
- this ->dim ,
444
- this ->one ,
445
- psi_iter,
446
- this ->dim ,
447
- psi_iter + nbase * this ->dim ,
448
- this ->dim ,
449
- this ->zero ,
450
- &scc[nbase * this ->nbase_x ],
451
- this ->nbase_x );
462
+ (' C' ,
463
+ ' N' ,
464
+ nbase + notconv,
465
+ notconv,
466
+ this ->dim ,
467
+ this ->one ,
468
+ psi_iter,
469
+ this ->dim ,
470
+ spsi + nbase * this ->dim ,
471
+ this ->dim ,
472
+ this ->zero ,
473
+ &scc[nbase * this ->nbase_x ],
474
+ this ->nbase_x );
452
475
453
476
#ifdef __MPI
454
477
if (this ->diag_comm .nproc > 1 )
@@ -685,10 +708,11 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
685
708
const Real* eigenvalue_in_hsolver,
686
709
// const psi::Psi<T, Device>& psi,
687
710
T* psi_iter,
688
- T* hp,
689
- T* sp,
690
- T* hc,
691
- T* vc)
711
+ T* hphi,
712
+ T* sphi,
713
+ T* hcc,
714
+ T* scc,
715
+ T* vcc)
692
716
{
693
717
ModuleBase::timer::tick (" Diago_DavSubspace" , " refresh" );
694
718
@@ -714,6 +738,28 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
714
738
// update hphi
715
739
syncmem_complex_op ()(hphi, psi_iter + nband * this ->dim , this ->dim * nband);
716
740
741
+ #ifdef __DSP
742
+ ModuleBase::gemm_op_mt<T, Device>()
743
+ #else
744
+ ModuleBase::gemm_op<T, Device>()
745
+ #endif
746
+ (' N' ,
747
+ ' N' ,
748
+ this ->dim ,
749
+ nband,
750
+ nbase,
751
+ this ->one ,
752
+ this ->sphi ,
753
+ this ->dim ,
754
+ this ->vcc ,
755
+ this ->nbase_x ,
756
+ this ->zero ,
757
+ psi_iter + nband * this ->dim ,
758
+ this ->dim );
759
+
760
+ // update sphi
761
+ syncmem_complex_op ()(sphi, psi_iter + nband * this ->dim , this ->dim * nband);
762
+
717
763
nbase = nband;
718
764
719
765
// set hcc/scc/vcc to 0
@@ -776,6 +822,7 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
776
822
777
823
template <typename T, typename Device>
778
824
int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
825
+ const HPsiFunc& spsi_func,
779
826
T* psi_in,
780
827
const int psi_in_dmax,
781
828
Real* eigenvalue_in_hsolver,
@@ -791,7 +838,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,
791
838
do
792
839
{
793
840
794
- sum_iter += this ->diag_once (hpsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
841
+ sum_iter += this ->diag_once (hpsi_func, spsi_func, psi_in, psi_in_dmax, eigenvalue_in_hsolver, ethr_band);
795
842
796
843
++ntry;
797
844
0 commit comments