1 change: 1 addition & 0 deletions source/Makefile.Objects
@@ -469,6 +469,7 @@ OBJS_PW=fft_bundle.o\
pw_init.o\
pw_transform.o\
pw_transform_k.o\
+pw_transform_convolution.o\

OBJS_RELAXATION=bfgs_basic.o\
relax_driver.o\
1 change: 1 addition & 0 deletions source/source_basis/module_pw/CMakeLists.txt
@@ -32,6 +32,7 @@ list(APPEND objects
pw_transform.cpp
pw_transform_gpu.cpp
pw_transform_k.cpp
+pw_transform_convolution.cpp
module_fft/fft_bundle.cpp
module_fft/fft_cpu.cpp
${FFT_SRC}
4 changes: 2 additions & 2 deletions source/source_basis/module_pw/module_fft/fft_cuda.cpp
@@ -16,13 +16,13 @@ template <>
void FFT_CUDA<float>::setupFFT()
{
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
-    resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);
+    resmem_cd_op()(this->c_auxr_3d, 2*this->nx * this->ny * this->nz);
Collaborator: Why is '2' added here?
}
template <>
void FFT_CUDA<double>::setupFFT()
{
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
-    resmem_zd_op()(this->z_auxr_3d, this->nx * this->ny * this->nz);
+    resmem_zd_op()(this->z_auxr_3d, 2*this->nx * this->ny * this->nz);
Collaborator: Why is '2' added here?
Collaborator (Author): The error has been fixed.
}
template <>
void FFT_CUDA<float>::cleanFFT()
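For context on the review question above: an in-place C2C transform of an nx*ny*nz grid needs exactly nx*ny*nz complex elements of workspace, so the extra factor of 2 looked wrong (and, per the author's reply, was later removed). Below is a minimal standalone sketch of the same setup using raw CUDA runtime calls in place of ABACUS's resmem_cd_op() allocator; it is an illustration, not the project's code:

#include <cufft.h>
#include <cuda_runtime.h>

int main()
{
    const int nx = 16, ny = 16, nz = 16;

    // 3-D single-precision complex-to-complex plan, as in setupFFT()
    cufftHandle plan;
    cufftPlan3d(&plan, nx, ny, nz, CUFFT_C2C);

    // One complex element per grid point suffices for an in-place C2C transform
    cufftComplex* buf = nullptr;
    cudaMalloc(reinterpret_cast<void**>(&buf), sizeof(cufftComplex) * nx * ny * nz);

    // ... fill buf on the device, then transform in place:
    cufftExecC2C(plan, buf, buf, CUFFT_FORWARD);

    cudaFree(buf);
    cufftDestroy(plan);
    return 0;
}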
92 changes: 90 additions & 2 deletions source/source_basis/module_pw/pw_basis_k.h
@@ -158,7 +158,6 @@ class PW_Basis_K : public PW_Basis
const int ik,
const bool add = false,
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
-
#endif

template <typename FPTYPE, typename Device>
@@ -176,7 +175,6 @@ class PW_Basis_K : public PW_Basis
const bool add = false,
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)

-
template <typename TK,
typename Device,
typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
@@ -245,7 +243,97 @@ class PW_Basis_K : public PW_Basis
{
this->recip2real_gpu(in, out, ik, add, factor);
}
+    template <typename FPTYPE, typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0>
+    void convolution(const int ik,
+                     const int size,
+                     const FPTYPE* input,
+                     const typename GetTypeReal<FPTYPE>::type* input1,
+                     FPTYPE* output,
+                     const bool add = false,
+                     const typename GetTypeReal<FPTYPE>::type factor =1.0) const
+    {
+        this->convolution_gpu(ik, size,input, input1, output, add, factor);
+    }
+    template <typename FPTYPE, typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void convolution(const int ik,
+                     const int size,
+                     const int max_npw,
+                     const FPTYPE* input,
+                     FPTYPE* tmp,
+                     FPTYPE* input1,
+                     FPTYPE* output,
+                     const bool add =false,
+                     const typename GetTypeReal<FPTYPE>::type factor =1.0) const
+    {
+        this->convolution_cpu(ik, size, max_npw,input, tmp, input1, output ,add,factor);
+    }
+
+    template <typename FPTYPE, typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0>
+    void convolution(const int ik,
+                     const int size,
+                     const int max_npw,
+                     const FPTYPE* input,
+                     FPTYPE* tmp,
+                     FPTYPE* input1,
+                     FPTYPE* output,
+                     const bool add=false,
+                     const typename GetTypeReal<FPTYPE>::type factor =1.0) const
+    {
+        this->convolution_gpu(ik, size, max_npw,input, tmp, input1, output ,add ,factor);
+    }
+    template <typename FPTYPE, typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void convolution(const int ik,
+                     const int size,
+                     const FPTYPE* input,
+                     const typename GetTypeReal<FPTYPE>::type* input1,
+                     FPTYPE* output,
+                     const bool add = false,
+                     const typename GetTypeReal<FPTYPE>::type factor =1.0) const
+    {
+        this->convolution_cpu(ik, size, input, input1, output, add, factor);
+    }
+    template <typename FPTYPE>
+    void convolution_cpu(const int ik,
+                         const int size,
+                         const std::complex<FPTYPE>* input,
+                         const FPTYPE* input1,
+                         std::complex<FPTYPE>* output,
+                         const bool add = false,
+                         const FPTYPE factor = 1.0) const;
+
+    template <typename FPTYPE>
+    void convolution_gpu(const int ik,
+                         const int size,
+                         const std::complex<FPTYPE>* input,
+                         const FPTYPE* input1,
+                         std::complex<FPTYPE>* output,
+                         const bool add = false,
+                         const FPTYPE factor = 1.0) const;
+    template <typename FPTYPE>
+    void convolution_cpu(const int ik,
+                         const int size,
+                         const int max_npw,
+                         const std::complex<FPTYPE>* input,
+                         std::complex<FPTYPE>* tmp,
+                         std::complex<FPTYPE>* input1,
+                         std::complex<FPTYPE>* output,
+                         const bool add = false,
+                         const FPTYPE factor = 1.0) const;
+
+    template <typename FPTYPE>
+    void convolution_gpu(const int ik,
+                         const int size,
+                         const int max_npw,
+                         const std::complex<FPTYPE>* input,
+                         std::complex<FPTYPE>* tmp,
+                         std::complex<FPTYPE>* input1,
+                         std::complex<FPTYPE>* output,
+                         const bool add = false,
+                         const FPTYPE factor = 1.0) const;
public:
//operator:
//get (G+K)^2:
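All of the convolution() overloads above use the same idiom: std::enable_if on the Device template parameter removes one overload from the candidate set at compile time, so a single call site dispatches to convolution_cpu or convolution_gpu. A compilable sketch of that idiom follows; DEVICE_CPU, DEVICE_GPU, and the printf bodies are simplified stand-ins, not the real base_device types or kernels:

#include <complex>
#include <cstdio>
#include <type_traits>

struct DEVICE_CPU {};
struct DEVICE_GPU {};

struct Basis
{
    // Selected only when Device == DEVICE_CPU
    template <typename FPTYPE, typename Device,
              typename std::enable_if<std::is_same<Device, DEVICE_CPU>::value, int>::type = 0>
    void convolution(const FPTYPE* in, FPTYPE* out) const
    {
        std::printf("CPU path\n"); // would forward to convolution_cpu(...)
    }

    // Selected only when Device == DEVICE_GPU
    template <typename FPTYPE, typename Device,
              typename std::enable_if<std::is_same<Device, DEVICE_GPU>::value, int>::type = 0>
    void convolution(const FPTYPE* in, FPTYPE* out) const
    {
        std::printf("GPU path\n"); // would forward to convolution_gpu(...)
    }
};

int main()
{
    Basis b;
    std::complex<double> in[4] = {}, out[4] = {};
    b.convolution<std::complex<double>, DEVICE_CPU>(in, out); // prints "CPU path"
    b.convolution<std::complex<double>, DEVICE_GPU>(in, out); // prints "GPU path"
    return 0;
}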
5 changes: 2 additions & 3 deletions source/source_basis/module_pw/pw_basis_sup.cpp
@@ -412,8 +412,7 @@ void PW_Basis_Sup::get_ig2isz_is2fftixy(
{
int z = iz;
if (z < 0) {
-                z += this->nz;
-            }
+                z += this->nz;}
if (!found[ixy * this->nz + z])
{
found[ixy * this->nz + z] = true;
@@ -422,7 +421,7 @@
pw_filled++;
if (xprime && ixy / fftny == 0) {
ng_xeq0++;
-            }
+        }
}
}
}
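The first hunk in this file only moves the closing brace; the logic is unchanged: a grid index iz that may be negative is folded into [0, nz) by one conditional add, the usual periodic-index trick for FFT grids. A tiny standalone illustration (the fold() helper is hypothetical, not ABACUS code):

#include <cstdio>

// Fold an index from (-nz, nz) into [0, nz)
int fold(int iz, int nz)
{
    int z = iz;
    if (z < 0) {
        z += nz; // e.g. iz = -1, nz = 64  ->  z = 63
    }
    return z;
}

int main()
{
    std::printf("%d %d\n", fold(-1, 64), fold(3, 64)); // prints: 63 3
    return 0;
}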