Skip to content

Commit 63fb3dd

Browse files
committed
change compute mode
1 parent 4480e17 commit 63fb3dd

File tree

1 file changed

+59
-53
lines changed
  • source/module_hamilt_pw/hamilt_pwdft/operator_pw

1 file changed

+59
-53
lines changed

source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -55,76 +55,82 @@ void Veff<OperatorPW<T, Device>>::act(
5555

5656
int max_npw = nbasis / npol;
5757
const int current_spin = this->isk[this->ik];
58-
58+
const int psi_offset= max_npw * npol;
5959
#ifdef __DSP
60-
ModulePW::FFT_Guard guard(wfcpw->fft_bundle);
61-
for (int ib = 0; ib<nbands ; ib += npol)
60+
if (npol == 1 && this->veff_col !=0)
6261
{
63-
if (npol == 1)
62+
ModulePW::FFT_Guard guard(wfcpw->fft_bundle);
63+
for (int ib = 0; ib < nbands; ib += npol)
6464
{
6565
wfcpw->convolution(this->ctx,
66-
this->ik,
67-
this->veff_col,
68-
tmpsi_in,
69-
this->veff+ current_spin* this->veff_col,
70-
tmhpsi,
71-
true);
72-
}else{
73-
// Should be replaced in the Convolution way.
74-
wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
75-
wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik);
76-
if(this->veff_col != 0)
77-
{
78-
/// denghui added at 20221109
79-
const Real* current_veff[4];
80-
for(int is = 0; is < 4; is++)
81-
{
82-
current_veff[is] = this->veff + is * this->veff_col ; // for CPU device
83-
}
84-
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
85-
}
86-
// FFT back to G space.
87-
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
88-
wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
66+
this->ik,
67+
this->veff_col,
68+
tmpsi_in,
69+
this->veff + current_spin * this->veff_col,
70+
tmhpsi,
71+
true);
72+
tmhpsi += psi_offset;
73+
tmpsi_in += psi_offset;
74+
}
75+
}else if (npol == 2 && this->veff_col !=0)
76+
{
77+
const Real* current_veff[4];
78+
for (int is = 0; is < 4; is++)
79+
{
80+
current_veff[is] = this->veff + is * this->veff_col;
81+
}
82+
for (int ib = 0; ib < nbands; ib += npol)
83+
{
84+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
85+
wfcpw->recip_to_real<T, Device>(tmpsi_in + max_npw, this->porter1, this->ik);
86+
87+
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
88+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
89+
wfcpw->real_to_recip<T, Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
90+
tmhpsi += psi_offset;
91+
tmpsi_in += psi_offset;
8992
}
93+
}else{
94+
ModuleBase::WARNING_QUIT("VeffPW", "npol should be 1 or 2 or veff_col equal to 0\n");
9095
}
91-
#endif
92-
for (int ib = 0; ib < nbands; ib += npol)
96+
#else
97+
if (npol == 1 && this->veff_col !=0)
9398
{
94-
if (npol == 1)
99+
for (int ib = 0; ib < nbands; ib += npol)
95100
{
96-
wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
101+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
97102
// NOTICE: when MPI threads are larger than the number of Z grids
98103
// veff would contain nothing, and nothing should be done in real space
99104
// but the 3DFFT can not be skipped, it will cause hanging
100-
if(this->veff_col != 0)
101-
{
102-
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
103-
}
104-
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
105+
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
106+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
107+
tmhpsi += psi_offset;
108+
tmpsi_in += psi_offset;
109+
}
110+
}
111+
else if (npol == 2 && this->veff_col !=0)
112+
{
113+
const Real* current_veff[4];
114+
for (int is = 0; is < 4; is++)
115+
{
116+
current_veff[is] = this->veff + is * this->veff_col;
105117
}
106-
else
118+
for (int ib = 0; ib < nbands; ib += npol)
107119
{
108120
// FFT to real space and do things.
109-
wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
110-
wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik);
111-
if(this->veff_col != 0)
112-
{
113-
/// denghui added at 20221109
114-
const Real* current_veff[4];
115-
for(int is = 0; is < 4; is++)
116-
{
117-
current_veff[is] = this->veff + is * this->veff_col ; // for CPU device
118-
}
119-
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
120-
}
121+
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
122+
wfcpw->recip_to_real<T, Device>(tmpsi_in + max_npw, this->porter1, this->ik);
123+
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
121124
// FFT back to G space.
122-
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
123-
wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
125+
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
126+
wfcpw->real_to_recip<T, Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
127+
tmhpsi += max_npw * npol;
128+
tmpsi_in += max_npw * npol;
124129
}
125-
tmhpsi += max_npw * npol;
126-
tmpsi_in += max_npw * npol;
130+
}else{
131+
ModuleBase::WARNING_QUIT("VeffPW", "npol should be 1 or 2 or veff_col equal to 0\n");
127132
}
133+
#endif
128134
ModuleBase::timer::tick("Operator", "veff_pw");
129135
}
130136

0 commit comments

Comments
 (0)