Skip to content

Commit 071a827

Browse files
committed
add change for gpu
1 parent eab71bb commit 071a827

File tree

5 files changed

+90
-29
lines changed

5 files changed

+90
-29
lines changed

source/source_basis/module_pw/pw_basis_sup.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,7 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy(
412412
{
413413
int z = iz;
414414
if (z < 0) {
415-
z += this->nz;
416-
}
415+
z += this->nz;}
417416
if (!found[ixy * this->nz + z])
418417
{
419418
found[ixy * this->nz + z] = true;
@@ -422,7 +421,7 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy(
422421
pw_filled++;
423422
if (xprime && ixy / fftny == 0) {
424423
ng_xeq0++;
425-
}
424+
}
426425
}
427426
}
428427
}

source/source_pw/module_pwdft/kernels/cuda/veff_op.cu

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,50 @@ __global__ void veff_pw(
2626
const int size,
2727
thrust::complex<FPTYPE>* out,
2828
thrust::complex<FPTYPE>* out1,
29-
const FPTYPE* in)
29+
thrust::complex<FPTYPE>* in)
3030
{
3131
int idx = blockIdx.x * blockDim.x + threadIdx.x;
3232
if(idx >= size) {return;}
33-
thrust::complex<FPTYPE> sup =
34-
out[idx] * (in[0 * size + idx] + in[3 * size + idx])
35-
+ out1[idx] * (in[1 * size + idx] - thrust::complex<FPTYPE>(0.0, 1.0) * in[2 * size + idx]);
36-
thrust::complex<FPTYPE> sdown =
37-
out1[idx] * (in[0 * size + idx] - in[3 * size + idx])
38-
+ out[idx] * (in[1 * size + idx] + thrust::complex<FPTYPE>(0.0, 1.0) * in[2 * size + idx]);
33+
const int base = idx * 4;
34+
thrust::complex<FPTYPE> sup = out[idx] * in[base] + out1[idx] * in[base+1];
35+
thrust::complex<FPTYPE> sdown = out1[idx] * in[base+2] + out[idx] * in[base+3];
3936
out[idx] = sup;
4037
out1[idx] = sdown;
4138
}
4239

40+
template <typename FPTYPE>
41+
__global__ void rearrange_op(
42+
const int size,
43+
const FPTYPE* in,
44+
thrust::complex<FPTYPE>* out)
45+
{
46+
int idx = blockIdx.x * blockDim.x + threadIdx.x;
47+
if(idx >= size) {return;}
48+
const int base = idx * 4;
49+
const FPTYPE part_1 = in[idx];
50+
const FPTYPE part_2 = in[idx + size];
51+
const FPTYPE part_3 = in[idx + 2 * size];
52+
const FPTYPE part_4 = in[idx + 3 * size];
53+
out[base] = thrust::complex<FPTYPE>(part_1 + part_4, 0.0);
54+
out[base + 1] = thrust::complex<FPTYPE>(part_2 , -part_3);
55+
out[base + 2] = thrust::complex<FPTYPE>(part_1 - part_4, 0.0);
56+
out[base + 3] = thrust::complex<FPTYPE>(part_2, part_3);
57+
58+
}
59+
template <typename FPTYPE>
60+
void rearrange<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* device,
61+
const int& size,
62+
const FPTYPE* in,
63+
std::complex<FPTYPE>* out) const
64+
{
65+
const int block = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
66+
rearrange_op<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
67+
size, // control params
68+
in, // array of data
69+
reinterpret_cast<thrust::complex<FPTYPE>*>(out)); // array of data
70+
cudaCheckOnDebug();
71+
}
72+
4373
template <typename FPTYPE>
4474
void veff_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
4575
const int& size,
@@ -60,18 +90,20 @@ void veff_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::
6090
const int& size,
6191
std::complex<FPTYPE>* out,
6292
std::complex<FPTYPE>* out1,
63-
const FPTYPE** in)
93+
std::complex<FPTYPE>* in)
6494
{
6595
const int block = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
6696
veff_pw<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
6797
size, // control params
6898
reinterpret_cast<thrust::complex<FPTYPE>*>(out), // array of data
6999
reinterpret_cast<thrust::complex<FPTYPE>*>(out1), // array of data
70-
in[0]); // array of data
100+
reinterpret_cast<thrust::complex<FPTYPE>*>(in)); // array of data
71101

72102
cudaCheckOnDebug();
73103
}
74104

105+
template struct rearrange<float, base_device::DEVICE_GPU>;
106+
template struct rearrange<double, base_device::DEVICE_GPU>;
75107
template struct veff_pw_op<float, base_device::DEVICE_GPU>;
76108
template struct veff_pw_op<double, base_device::DEVICE_GPU>;
77109

source/source_pw/module_pwdft/kernels/veff_op.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct veff_pw_op<FPTYPE, base_device::DEVICE_CPU>
2020
const int& size,
2121
std::complex<FPTYPE>* out,
2222
std::complex<FPTYPE>* out1,
23-
const std::complex<FPTYPE>* in)
23+
std::complex<FPTYPE>* in)
2424
{
2525

2626
#ifdef _OPENMP
@@ -37,6 +37,31 @@ struct veff_pw_op<FPTYPE, base_device::DEVICE_CPU>
3737
}
3838
};
3939

40+
template<typename FPTYPE>
41+
struct rearrange<FPTYPE, base_device::DEVICE_CPU>
42+
{
43+
void operator()(const base_device::DEVICE_CPU* dev,
44+
const int& size,
45+
const FPTYPE* in,
46+
std::complex<FPTYPE>* out) const
47+
{
48+
for (int ir=0; ir < size; ir++)
49+
{
50+
const int base = 4 *ir;
51+
FPTYPE part_1 = in[ir];
52+
FPTYPE part_2 = in[ir + size];
53+
FPTYPE part_3 = in[ir + 2*size];
54+
FPTYPE part_4 = in[ir + 3*size];
55+
out[base ] = std::complex<FPTYPE>(part_1 + part_4, 0.0);
56+
out[base + 1] = std::complex<FPTYPE>(part_2 , -part_3);
57+
out[base + 2] = std::complex<FPTYPE>(part_1 - part_4, 0.0);
58+
out[base + 3] = std::complex<FPTYPE>(part_2, part_3);
59+
}
60+
}
61+
};
62+
63+
template struct rearrange<float, base_device::DEVICE_CPU>;
64+
template struct rearrange<double, base_device::DEVICE_CPU>;
4065
template struct veff_pw_op<float, base_device::DEVICE_CPU>;
4166
template struct veff_pw_op<double, base_device::DEVICE_CPU>;
4267

source/source_pw/module_pwdft/kernels/veff_op.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct veff_pw_op {
4848
const int& size,
4949
std::complex<FPTYPE>* out,
5050
std::complex<FPTYPE>* out1,
51-
const std::complex<FPTYPE>* in);
51+
std::complex<FPTYPE>* in);
5252
};
5353

5454
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
@@ -62,7 +62,25 @@ struct veff_pw_op<FPTYPE, base_device::DEVICE_GPU>
6262
const int& size,
6363
std::complex<FPTYPE>* out,
6464
std::complex<FPTYPE>* out1,
65-
const FPTYPE** in);
65+
std::complex<FPTYPE>* in);
66+
};
67+
68+
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
69+
template<typename FPTYPE, typename Device>
70+
struct rearrange
71+
{
72+
void operator()(const Device* device,const int& size, const FPTYPE* in, std::complex<FPTYPE>* out) const;
73+
};
74+
75+
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
76+
77+
template<typename FPTYPE>
78+
struct rearrange<FPTYPE, base_device::DEVICE_GPU>
79+
{
80+
void operator()(const base_device::DEVICE_GPU* device,
81+
const int& size,
82+
const FPTYPE* in,
83+
std::complex<FPTYPE>* out) const;
6684
};
6785
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
6886
} // namespace hamilt

source/source_pw/module_pwdft/operator_pw/veff_pw.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -114,20 +114,7 @@ void Veff<OperatorPW<T, Device>>::act(
114114
}
115115
else if (npol == 2)
116116
{
117-
const Real* current_veff={nullptr};
118-
const std::complex<Real> imag=std::complex<Real>(0.0, 1.0);
119-
for (int ir=0; ir < veff_col; ir++)
120-
{
121-
const int base = 4 *ir;
122-
Real part_1 = this->veff[ir];
123-
Real part_2 = this->veff[ir + veff_col];
124-
Real part_3 = this->veff[ir + 2*veff_col];
125-
Real part_4 = this->veff[ir + 3*veff_col];
126-
nspin_4_veff[base ] = part_1 + part_4;
127-
nspin_4_veff[base + 1] = part_2 - imag * part_3;
128-
nspin_4_veff[base + 2] = part_1 - part_4;
129-
nspin_4_veff[base + 3] = part_2 + imag * part_3;
130-
}
117+
rearrange<Real,Device>()(this->ctx, this->veff_col, this->veff, this->nspin_4_veff);
131118
for (int ib = 0; ib < nbands; ib += npol)
132119
{
133120
// FFT to real space and do things.

0 commit comments

Comments
 (0)