@@ -55,76 +55,82 @@ void Veff<OperatorPW<T, Device>>::act(
5555
5656 int max_npw = nbasis / npol;
5757 const int current_spin = this ->isk [this ->ik ];
58-
58+ const int psi_offset= max_npw * npol;
5959#ifdef __DSP
60- ModulePW::FFT_Guard guard (wfcpw->fft_bundle );
61- for (int ib = 0 ; ib<nbands ; ib += npol)
60+ if (npol == 1 && this ->veff_col !=0 )
6261 {
63- if (npol == 1 )
62+ ModulePW::FFT_Guard guard (wfcpw->fft_bundle );
63+ for (int ib = 0 ; ib < nbands; ib += npol)
6464 {
6565 wfcpw->convolution (this ->ctx ,
66- this ->ik ,
67- this ->veff_col ,
68- tmpsi_in,
69- this ->veff + current_spin* this ->veff_col ,
70- tmhpsi,
71- true );
72- }else {
73- // Should be replaced in the Convolution way.
74- wfcpw->recip_to_real <T,Device>(tmpsi_in, this ->porter , this ->ik );
75- wfcpw->recip_to_real <T,Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
76- if (this ->veff_col != 0 )
77- {
78- // / denghui added at 20221109
79- const Real* current_veff[4 ];
80- for (int is = 0 ; is < 4 ; is++)
81- {
82- current_veff[is] = this ->veff + is * this ->veff_col ; // for CPU device
83- }
84- veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
85- }
86- // FFT back to G space.
87- wfcpw->real_to_recip <T,Device>(this ->porter , tmhpsi, this ->ik , true );
88- wfcpw->real_to_recip <T,Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
66+ this ->ik ,
67+ this ->veff_col ,
68+ tmpsi_in,
69+ this ->veff + current_spin * this ->veff_col ,
70+ tmhpsi,
71+ true );
72+ tmhpsi += psi_offset;
73+ tmpsi_in += psi_offset;
74+ }
75+ }else if (npol == 2 && this ->veff_col !=0 )
76+ {
77+ const Real* current_veff[4 ];
78+ for (int is = 0 ; is < 4 ; is++)
79+ {
80+ current_veff[is] = this ->veff + is * this ->veff_col ;
81+ }
82+ for (int ib = 0 ; ib < nbands; ib += npol)
83+ {
84+ wfcpw->recip_to_real <T, Device>(tmpsi_in, this ->porter , this ->ik );
85+ wfcpw->recip_to_real <T, Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
86+
87+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
88+ wfcpw->real_to_recip <T, Device>(this ->porter , tmhpsi, this ->ik , true );
89+ wfcpw->real_to_recip <T, Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
90+ tmhpsi += psi_offset;
91+ tmpsi_in += psi_offset;
8992 }
93+ }else {
94+ ModuleBase::WARNING_QUIT (" VeffPW" , " npol should be 1 or 2 or veff_col equal to 0\n " );
9095 }
91- #endif
92- for ( int ib = 0 ; ib < nbands; ib += npol )
96+ #else
97+ if (npol == 1 && this -> veff_col != 0 )
9398 {
94- if (npol == 1 )
99+ for ( int ib = 0 ; ib < nbands; ib += npol )
95100 {
96- wfcpw->recip_to_real <T,Device>(tmpsi_in, this ->porter , this ->ik );
101+ wfcpw->recip_to_real <T, Device>(tmpsi_in, this ->porter , this ->ik );
97102 // NOTICE: when MPI threads are larger than the number of Z grids
98103 // veff would contain nothing, and nothing should be done in real space
99104 // but the 3DFFT can not be skipped, it will cause hanging
100- if (this ->veff_col != 0 )
101- {
102- veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->veff + current_spin * this ->veff_col );
103- }
104- wfcpw->real_to_recip <T,Device>(this ->porter , tmhpsi, this ->ik , true );
105+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->veff + current_spin * this ->veff_col );
106+ wfcpw->real_to_recip <T, Device>(this ->porter , tmhpsi, this ->ik , true );
107+ tmhpsi += psi_offset;
108+ tmpsi_in += psi_offset;
109+ }
110+ }
111+ else if (npol == 2 && this ->veff_col !=0 )
112+ {
113+ const Real* current_veff[4 ];
114+ for (int is = 0 ; is < 4 ; is++)
115+ {
116+ current_veff[is] = this ->veff + is * this ->veff_col ;
105117 }
106- else
118+ for ( int ib = 0 ; ib < nbands; ib += npol)
107119 {
108120 // FFT to real space and do things.
109- wfcpw->recip_to_real <T,Device>(tmpsi_in, this ->porter , this ->ik );
110- wfcpw->recip_to_real <T,Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
111- if (this ->veff_col != 0 )
112- {
113- // / denghui added at 20221109
114- const Real* current_veff[4 ];
115- for (int is = 0 ; is < 4 ; is++)
116- {
117- current_veff[is] = this ->veff + is * this ->veff_col ; // for CPU device
118- }
119- veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
120- }
121+ wfcpw->recip_to_real <T, Device>(tmpsi_in, this ->porter , this ->ik );
122+ wfcpw->recip_to_real <T, Device>(tmpsi_in + max_npw, this ->porter1 , this ->ik );
123+ veff_op ()(this ->ctx , this ->veff_col , this ->porter , this ->porter1 , current_veff);
121124 // FFT back to G space.
122- wfcpw->real_to_recip <T,Device>(this ->porter , tmhpsi, this ->ik , true );
123- wfcpw->real_to_recip <T,Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
125+ wfcpw->real_to_recip <T, Device>(this ->porter , tmhpsi, this ->ik , true );
126+ wfcpw->real_to_recip <T, Device>(this ->porter1 , tmhpsi + max_npw, this ->ik , true );
127+ tmhpsi += max_npw * npol;
128+ tmpsi_in += max_npw * npol;
124129 }
125- tmhpsi += max_npw * npol;
126- tmpsi_in += max_npw * npol ;
130+ } else {
131+ ModuleBase::WARNING_QUIT ( " VeffPW " , " npol should be 1 or 2 or veff_col equal to 0 \n " ) ;
127132 }
133+ #endif
128134 ModuleBase::timer::tick (" Operator" , " veff_pw" );
129135}
130136
0 commit comments