// Apply the local (spinor) effective potential to a pair of wavefunction
// components on the real-space grid.
//
// Expected launch: 1-D grid, one thread per grid point; threads beyond
// `size` exit immediately, so any ceil-div grid works.
//
// `in` holds, for every grid point idx, a 2x2 complex potential block packed
// contiguously at in[4*idx .. 4*idx+3] (the layout produced by rearrange_op):
//   sup   = out[idx]  * in[4*idx+0] + out1[idx] * in[4*idx+1]
//   sdown = out1[idx] * in[4*idx+2] + out[idx]  * in[4*idx+3]
// Both results are written back in place over out/out1.
//
// NOTE(review): `in` is only read here, so it is const-qualified; callers
// passing a non-const pointer are unaffected.
template <typename FPTYPE>
__global__ void veff_pw(
    const int size,
    thrust::complex<FPTYPE>* out,
    thrust::complex<FPTYPE>* out1,
    const thrust::complex<FPTYPE>* in)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size)
    {
        return;
    }
    const int base = idx * 4;
    // Read both spin components before writing either: sdown depends on the
    // pre-update value of out[idx] and sup on the pre-update out1[idx].
    thrust::complex<FPTYPE> sup = out[idx] * in[base] + out1[idx] * in[base + 1];
    thrust::complex<FPTYPE> sdown = out1[idx] * in[base + 2] + out[idx] * in[base + 3];
    out[idx] = sup;
    out1[idx] = sdown;
}
4239
// Repack the four real spin-potential planes (structure-of-arrays: one
// contiguous plane of `size` values per component) into per-point 2x2
// complex blocks (array-of-structures) in the layout consumed by veff_pw:
//   out[4*idx + 0] = v0 + v3      (spin-up diagonal, purely real)
//   out[4*idx + 1] = v1 - i*v2    (upper off-diagonal)
//   out[4*idx + 2] = v0 - v3      (spin-down diagonal, purely real)
//   out[4*idx + 3] = v1 + i*v2    (lower off-diagonal)
// Expected launch: 1-D grid, one thread per grid point; out must hold at
// least 4*size complex values.
template <typename FPTYPE>
__global__ void rearrange_op(
    const int size,
    const FPTYPE* in,
    thrust::complex<FPTYPE>* out)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size)
    {
        return;
    }
    const FPTYPE v0 = in[idx];
    const FPTYPE v1 = in[idx + size];
    const FPTYPE v2 = in[idx + 2 * size];
    const FPTYPE v3 = in[idx + 3 * size];
    thrust::complex<FPTYPE>* dst = out + idx * 4;
    dst[0] = thrust::complex<FPTYPE>(v0 + v3, 0.0);
    dst[1] = thrust::complex<FPTYPE>(v1, -v2);
    dst[2] = thrust::complex<FPTYPE>(v0 - v3, 0.0);
    dst[3] = thrust::complex<FPTYPE>(v1, v2);
}
// Host-side launcher for rearrange_op on the GPU backend.
//
// `in`  : device pointer, 4*size reals in SoA layout (four planes of size).
// `out` : device pointer, 4*size complex values, filled in AoS layout.
// Launches on the default stream; errors are surfaced by cudaCheckOnDebug().
template <typename FPTYPE>
void rearrange<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* device,
                                                            const int& size,
                                                            const FPTYPE* in,
                                                            std::complex<FPTYPE>* out) const
{
    // Guard: a zero-block grid (<<<0, N>>>) is a CUDA launch-configuration
    // error, so an empty problem must return before launching.
    if (size <= 0)
    {
        return;
    }
    const int block = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; // ceil-div
    rearrange_op<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
        size,                                            // control params
        in,                                              // array of data
        reinterpret_cast<thrust::complex<FPTYPE>*>(out)); // array of data
    cudaCheckOnDebug();
}
72+
4373template <typename FPTYPE>
4474void veff_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator ()(const base_device::DEVICE_GPU* dev,
4575 const int & size,
// Host-side launcher for the spinor veff_pw kernel on the GPU backend.
//
// `out`/`out1` : device pointers, `size` complex values each, updated in place.
// `in`         : device pointer, 4*size complex values in the AoS layout
//                produced by rearrange (2x2 potential block per grid point).
// Launches on the default stream; errors are surfaced by cudaCheckOnDebug().
template <typename FPTYPE>
void veff_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
                                                             const int& size,
                                                             std::complex<FPTYPE>* out,
                                                             std::complex<FPTYPE>* out1,
                                                             std::complex<FPTYPE>* in)
{
    // Guard: a zero-block grid (<<<0, N>>>) is a CUDA launch-configuration
    // error, so an empty problem must return before launching.
    if (size <= 0)
    {
        return;
    }
    const int block = (size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; // ceil-div
    veff_pw<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
        size,                                              // control params
        reinterpret_cast<thrust::complex<FPTYPE>*>(out),   // array of data
        reinterpret_cast<thrust::complex<FPTYPE>*>(out1),  // array of data
        reinterpret_cast<thrust::complex<FPTYPE>*>(in));   // array of data

    cudaCheckOnDebug();
}
74104
// Explicit template instantiations for the GPU backend, single and double
// precision, so the definitions in this .cu file are emitted for linkage.
template struct rearrange<float, base_device::DEVICE_GPU>;
template struct rearrange<double, base_device::DEVICE_GPU>;
template struct veff_pw_op<float, base_device::DEVICE_GPU>;
template struct veff_pw_op<double, base_device::DEVICE_GPU>;
77109
0 commit comments