@@ -52,52 +52,82 @@ void Veff<OperatorPW<T, Device>>::act(
5252 {
5353 setmem_complex_op ()(tmhpsi, 0 , nbasis*nbands/npol);
5454 }
55-
5655 int max_npw = nbasis / npol;
5756 const int current_spin = this ->isk [this ->ik ];
58-
57+ const int psi_offset= max_npw * npol;
5958#ifdef __DSP
60- wfcpw->fft_bundle .resource_handler (1 );
61- #endif
62-
63- for (int ib = 0 ; ib < nbands; ib += npol)
59+ if (npol == 1 )
60+ {
61+ ModulePW::FFT_Guard guard (wfcpw->fft_bundle );
62+ for (int ib = 0 ; ib < nbands; ib += npol)
63+ {
64+ wfcpw->convolution (this ->ctx ,
65+ this ->ik ,
66+ this ->veff_col ,
67+ tmpsi_in,
68+ this ->veff + current_spin * this ->veff_col ,
69+ tmhpsi,
70+ true );
71+ tmhpsi += psi_offset;
72+ tmpsi_in += psi_offset;
73+ }
74+ }else if (npol == 2 )
75+ {
76+ const Real* current_veff[4 ]={nullptr };
77+ for (int is = 0 ; is < 4 ; is++)
78+ {
79+ current_veff[is] = this ->veff + is * this ->veff_col ;
80+ }
81+ for (int ib = 0 ; ib < nbands; ib += npol)
82+ {
83+ wfcpw->recip_to_real <T, Device>(tmpsi_in, this ->porter , this ->ik );
84+ wfcpw->recip_to_real <T, Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
85+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
86+ wfcpw->real_to_recip <T, Device>(this ->porter , tmhpsi, this ->ik , true );
87+ wfcpw->real_to_recip <T, Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
88+ tmhpsi += psi_offset;
89+ tmpsi_in += psi_offset;
90+ }
91+ }else {
92+ ModuleBase::WARNING_QUIT (" VeffPW" , " npol should be 1 or 2 or veff_col equal to 0\n " );
93+ }
94+ #else
95+ if (npol == 1 )
6496 {
65- if (npol == 1 )
97+ for ( int ib = 0 ; ib < nbands; ib += npol )
6698 {
67- wfcpw->recip_to_real <T,Device>(tmpsi_in, this ->porter , this ->ik );
99+ wfcpw->recip_to_real <T, Device>(tmpsi_in, this ->porter , this ->ik );
68100 // NOTICE: when MPI threads are larger than the number of Z grids
69101 // veff would contain nothing, and nothing should be done in real space
70102 // but the 3DFFT can not be skipped, it will cause hanging
71- if (this ->veff_col != 0 )
72- {
73- veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->veff + current_spin * this ->veff_col );
74- }
75- wfcpw->real_to_recip <T,Device>(this ->porter , tmhpsi, this ->ik , true );
103+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->veff + current_spin * this ->veff_col );
104+ wfcpw->real_to_recip <T, Device>(this ->porter , tmhpsi, this ->ik , true );
105+ tmhpsi += psi_offset;
106+ tmpsi_in += psi_offset;
76107 }
77- else
108+ }
109+ else if (npol == 2 )
110+ {
111+ const Real* current_veff[4 ]={nullptr };
112+ for (int is = 0 ; is < 4 ; is++)
113+ {
114+ current_veff[is] = this ->veff + is * this ->veff_col ;
115+ }
116+ for (int ib = 0 ; ib < nbands; ib += npol)
78117 {
79118 // FFT to real space and do things.
80- wfcpw->recip_to_real <T,Device>(tmpsi_in, this ->porter , this ->ik );
81- wfcpw->recip_to_real <T,Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
82- if (this ->veff_col != 0 )
83- {
84- // / denghui added at 20221109
85- const Real* current_veff[4 ];
86- for (int is = 0 ; is < 4 ; is++)
87- {
88- current_veff[is] = this ->veff + is * this ->veff_col ; // for CPU device
89- }
90- veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
91- }
119+ wfcpw->recip_to_real <T, Device>(tmpsi_in, this ->porter , this ->ik );
120+ wfcpw->recip_to_real <T, Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
121+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
92122 // FFT back to G space.
93- wfcpw->real_to_recip <T,Device>(this ->porter , tmhpsi, this ->ik , true );
94- wfcpw->real_to_recip <T,Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
123+ wfcpw->real_to_recip <T, Device>(this ->porter , tmhpsi, this ->ik , true );
124+ wfcpw->real_to_recip <T, Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
125+ tmhpsi += psi_offset;
126+ tmpsi_in += psi_offset;
95127 }
96- tmhpsi += max_npw * npol;
97- tmpsi_in += max_npw * npol ;
128+ } else {
129+ ModuleBase::WARNING_QUIT ( " VeffPW " , " npol should be 1 or 2 or veff_col equal to 0 \n " ) ;
98130 }
99- #ifdef __DSP
100- wfcpw->fft_bundle .resource_handler (0 );
101131#endif
102132 ModuleBase::timer::tick (" Operator" , " veff_pw" );
103133}
0 commit comments