Skip to content

Commit eab71bb

Browse files
committed
add change for the nspin_4
1 parent 4b770fb commit eab71bb

File tree

4 files changed

+25
-15
lines changed

4 files changed

+25
-15
lines changed

source/source_pw/module_pwdft/kernels/veff_op.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,17 @@ struct veff_pw_op<FPTYPE, base_device::DEVICE_CPU>
2020
const int& size,
2121
std::complex<FPTYPE>* out,
2222
std::complex<FPTYPE>* out1,
23-
const FPTYPE** in)
23+
const std::complex<FPTYPE>* in)
2424
{
25+
2526
#ifdef _OPENMP
2627
#pragma omp parallel for
2728
#endif
28-
for (int ir = 0; ir < size; ir++) {
29-
auto sup = out[ir] * (in[0][ir] + in[3][ir])
30-
+ out1[ir]
31-
* (in[1][ir]
32-
- std::complex<FPTYPE>(0.0, 1.0) * in[2][ir]);
33-
auto sdown = out1[ir] * (in[0][ir] - in[3][ir])
34-
+ out[ir]
35-
* (in[1][ir]
36-
+ std::complex<FPTYPE>(0.0, 1.0) * in[2][ir]);
29+
for (int ir = 0; ir < size; ir++)
30+
{
31+
const int base = ir * 4;
32+
auto sup = out[ir] * (in[base]) + out1[ir] * (in[base + 1]);
33+
auto sdown = out1[ir] * (in[base + 2]) + out[ir] * (in[base + 3]);
3734
out[ir] = sup;
3835
out1[ir] = sdown;
3936
}

source/source_pw/module_pwdft/kernels/veff_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct veff_pw_op {
4848
const int& size,
4949
std::complex<FPTYPE>* out,
5050
std::complex<FPTYPE>* out1,
51-
const FPTYPE** in);
51+
const std::complex<FPTYPE>* in);
5252
};
5353

5454
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM

source/source_pw/module_pwdft/operator_pw/veff_pw.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Veff<OperatorPW<T, Device>>::Veff(const int* isk_in,
2525
this->veff_row = veff_row;
2626
this->veff_col = veff_col;
2727
this->wfcpw = wfcpw_in;
28+
resmem_complex_op()(nspin_4_veff, 4*veff_col, "Veff<PW>::porter");
29+
// this->nspin_4_veff=new std::complex<double>[4*veff_row];
2830
resmem_complex_op()(this->porter, this->wfcpw->nmaxgr, "Veff<PW>::porter");
2931
resmem_complex_op()(this->porter1, this->wfcpw->nmaxgr, "Veff<PW>::porter1");
3032

@@ -35,6 +37,7 @@ Veff<OperatorPW<T, Device>>::~Veff()
3537
{
3638
delmem_complex_op()(this->porter);
3739
delmem_complex_op()(this->porter1);
40+
delmem_complex_op()(this->nspin_4_veff);
3841
}
3942

4043
template<typename T, typename Device>
@@ -111,17 +114,26 @@ void Veff<OperatorPW<T, Device>>::act(
111114
}
112115
else if (npol == 2)
113116
{
114-
const Real* current_veff[4]={nullptr};
115-
for (int is = 0; is < 4; is++)
117+
const Real* current_veff={nullptr};
118+
const std::complex<Real> imag=std::complex<Real>(0.0, 1.0);
119+
for (int ir=0; ir < veff_col; ir++)
116120
{
117-
current_veff[is] = this->veff + is * this->veff_col;
121+
const int base = 4 *ir;
122+
Real part_1 = this->veff[ir];
123+
Real part_2 = this->veff[ir + veff_col];
124+
Real part_3 = this->veff[ir + 2*veff_col];
125+
Real part_4 = this->veff[ir + 3*veff_col];
126+
nspin_4_veff[base ] = part_1 + part_4;
127+
nspin_4_veff[base + 1] = part_2 - imag * part_3;
128+
nspin_4_veff[base + 2] = part_1 - part_4;
129+
nspin_4_veff[base + 3] = part_2 + imag * part_3;
118130
}
119131
for (int ib = 0; ib < nbands; ib += npol)
120132
{
121133
// FFT to real space and do things.
122134
wfcpw->recip_to_real<T, Device>(tmpsi_in, this->porter, this->ik);
123135
wfcpw->recip_to_real<T, Device>(tmpsi_in + max_npw, this->porter1, this->ik);
124-
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
136+
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, nspin_4_veff);
125137
// FFT back to G space.
126138
wfcpw->real_to_recip<T, Device>(this->porter, tmhpsi, this->ik, true);
127139
wfcpw->real_to_recip<T, Device>(this->porter1, tmhpsi + max_npw, this->ik, true);

source/source_pw/module_pwdft/operator_pw/veff_pw.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class Veff<OperatorPW<T, Device>> : public OperatorPW<T, Device>
6868
const Real *veff = nullptr, *h_veff = nullptr, *d_veff = nullptr;
6969
T *porter = nullptr;
7070
T *porter1 = nullptr;
71+
mutable T* nspin_4_veff=nullptr;
7172
base_device::AbacusDevice_t device = {};
7273
using veff_op = veff_pw_op<Real, Device>;
7374

0 commit comments

Comments
 (0)