Skip to content
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
f7b8e0d
first use of the fft
A-006 Mar 25, 2025
af7e835
fix the cuda test
A-006 Mar 31, 2025
eefe6c5
add gtest
A-006 Mar 31, 2025
150e63d
change the test
A-006 Mar 31, 2025
ecc7f23
change recip_to_real func
A-006 Mar 31, 2025
95c1108
set C2C
A-006 Mar 31, 2025
2fdecfa
set C2C and C2R
A-006 Apr 1, 2025
0f3bf0a
set C2C and C2R
A-006 Apr 1, 2025
79a48f0
add pw_basis_k C2C
A-006 Apr 1, 2025
5065730
add CmakeLists.txt
A-006 Apr 1, 2025
9abbfe2
Merge branch 'develop' into fft12
A-006 Apr 1, 2025
019a2b9
remove compile
A-006 Apr 1, 2025
d88177f
add the define
A-006 Apr 1, 2025
6820901
add delete
A-006 Apr 1, 2025
d1a8a05
Merge branch 'develop' into fft12
A-006 Apr 1, 2025
09769ef
Merge branch 'develop' into fft12
A-006 Apr 1, 2025
7d5d77c
modify the include file
A-006 Apr 2, 2025
a96be8f
Merge branch 'develop' into fft12
A-006 Apr 2, 2025
f482e59
add change for the POOL_WORLD
A-006 Apr 2, 2025
c8820e3
Merge branch 'develop' into fft12
A-006 Apr 2, 2025
10f0ad2
format document
A-006 Apr 3, 2025
069dd0c
Merge branch 'develop' into fft12
A-006 Apr 3, 2025
82bf06f
Merge branch 'develop' into fft12
A-006 Apr 4, 2025
501580a
add the mpi_flag_
A-006 Apr 5, 2025
88b3605
Merge branch 'develop' into fft12
A-006 Apr 5, 2025
ada3a8d
fix bug in the mpi set
A-006 Apr 6, 2025
eb419aa
Merge branch 'develop' into fft12
A-006 Apr 6, 2025
597d012
modify the mpi_flag_ as flase
A-006 Apr 7, 2025
26ff18e
Merge branch 'develop' into fft12
A-006 Apr 8, 2025
85785a1
set pw_test
A-006 Apr 8, 2025
d977a71
Merge branch 'develop' into fft12
A-006 Apr 8, 2025
e240efc
Revert "set pw_test"
A-006 Apr 8, 2025
e67db12
revert setup
A-006 Apr 8, 2025
9f41593
Merge branch 'develop' into fft12
A-006 Apr 9, 2025
b4d1226
Merge branch 'develop' into fft12
Qianruipku Apr 9, 2025
10843f5
add comment
A-006 Apr 9, 2025
7490f7e
Merge branch 'develop' into fft12
A-006 Apr 10, 2025
7280235
update offset
A-006 Apr 12, 2025
fc6ecc6
Merge branch 'develop' into fft12
A-006 Apr 12, 2025
af0b97f
Merge branch 'develop' into fft12
A-006 Apr 14, 2025
e5f5a42
reset the for
A-006 Apr 18, 2025
cedaae6
change the mpi_flag
A-006 Apr 18, 2025
c5d6524
change the compile error
A-006 Apr 18, 2025
b2e1c4a
Merge branch 'develop' into fft12
A-006 Apr 18, 2025
1026030
sperate two lines
A-006 Apr 18, 2025
d382d6f
add change
A-006 Apr 21, 2025
3c02d02
use template instead of change
A-006 Apr 22, 2025
dc7791e
remove tem file
A-006 Apr 22, 2025
17212af
Merge branch 'develop' into fft12
A-006 Apr 22, 2025
ad2de2b
Merge branch 'develop' into fft12
mohanchen Apr 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/module_base/module_device/memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ struct synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVIC
void operator()(FPTYPE* arr_out,
const FPTYPE* arr_in,
const size_t size);

};

template <typename FPTYPE>
Expand Down
1 change: 1 addition & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,6 @@ if(BUILD_TESTING)
add_subdirectory(test)
add_subdirectory(test_serial)
add_subdirectory(kernels/test)
add_subdirectory(test_gpu)
endif()
endif()
52 changes: 26 additions & 26 deletions source/module_basis/module_pw/module_fft/fft_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
#include "fft_cuda.h"

#include "module_base/module_device/memory_op.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"

namespace ModulePW
{
template <typename FPTYPE>
void FFT_CUDA<FPTYPE>::initfft(int nx_in,
int ny_in,
int nz_in)
void FFT_CUDA<FPTYPE>::initfft(int nx_in, int ny_in, int nz_in)
{
this->nx = nx_in;
this->ny = ny_in;
Expand All @@ -18,9 +17,8 @@ void FFT_CUDA<float>::setupFFT()
{
cufftPlan3d(&c_handle, this->nx, this->ny, this->nz, CUFFT_C2C);
resmem_cd_op()(this->c_auxr_3d, this->nx * this->ny * this->nz);

}
template <>
template <>
void FFT_CUDA<double>::setupFFT()
{
cufftPlan3d(&z_handle, this->nx, this->ny, this->nz, CUFFT_Z2Z);
Expand Down Expand Up @@ -66,49 +64,51 @@ void FFT_CUDA<double>::clear()
}

template <>
void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in,
std::complex<float>* out) const
void FFT_CUDA<float>::fft3D_forward(std::complex<float>* in, std::complex<float>* out) const
{
CHECK_CUFFT(cufftExecC2C(this->c_handle,
reinterpret_cast<cufftComplex*>(in),
CHECK_CUFFT(cufftExecC2C(this->c_handle,
reinterpret_cast<cufftComplex*>(in),
reinterpret_cast<cufftComplex*>(out),
CUFFT_FORWARD));
}
template <>
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in,
std::complex<double>* out) const
void FFT_CUDA<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
{
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
reinterpret_cast<cufftDoubleComplex*>(in),
reinterpret_cast<cufftDoubleComplex*>(out),
reinterpret_cast<cufftDoubleComplex*>(out),
CUFFT_FORWARD));
}
template <>
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in,
std::complex<float>* out) const
void FFT_CUDA<float>::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const
{
CHECK_CUFFT(cufftExecC2C(this->c_handle,
reinterpret_cast<cufftComplex*>(in),
CHECK_CUFFT(cufftExecC2C(this->c_handle,
reinterpret_cast<cufftComplex*>(in),
reinterpret_cast<cufftComplex*>(out),
CUFFT_INVERSE));
}

template <>
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in,
std::complex<double>* out) const
void FFT_CUDA<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
{
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
CHECK_CUFFT(cufftExecZ2Z(this->z_handle,
reinterpret_cast<cufftDoubleComplex*>(in),
reinterpret_cast<cufftDoubleComplex*>(out),
reinterpret_cast<cufftDoubleComplex*>(out),
CUFFT_INVERSE));
}
template <> std::complex<float>*
FFT_CUDA<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
template <> std::complex<double>*
FFT_CUDA<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
template <>
std::complex<float>* FFT_CUDA<float>::get_auxr_3d_data() const
{
return this->c_auxr_3d;
}
template <>
std::complex<double>* FFT_CUDA<double>::get_auxr_3d_data() const
{
return this->z_auxr_3d;
}

template FFT_CUDA<float>::FFT_CUDA();
template FFT_CUDA<float>::~FFT_CUDA();
template FFT_CUDA<double>::FFT_CUDA();
template FFT_CUDA<double>::~FFT_CUDA();
}// namespace ModulePW
} // namespace ModulePW
2 changes: 2 additions & 0 deletions source/module_basis/module_pw/pw_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ PW_Basis:: ~PW_Basis()
if (this->device == "gpu")
{
delmem_int_op()(this->d_is2fftixy);
delmem_int_op()(this->ig2ixyz_gpu);
}
#endif
}
Expand All @@ -59,6 +60,7 @@ void PW_Basis::setuptransform()
this->distribute_g();
this->getstartgr();
this->fft_bundle.clear();

if(this->xprime)
{
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
Expand Down
22 changes: 14 additions & 8 deletions source/module_basis/module_pw/pw_basis.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ class PW_Basis
//===============================================
public:
#ifdef __MPI
MPI_Comm pool_world;
MPI_Comm pool_world=MPI_COMM_NULL;
#endif

int *ig2isz=nullptr; // map ig to (is, iz).
int *ig2ixyz_gpu = nullptr; // map ig to (iz + iy * nz + ix * ny * nz); filled through syncmem_int_h2d_op in get_ig2isz_is2fftixy when device == "gpu", released in ~PW_Basis.
int *istot2ixy=nullptr; // istot2ixy[is]: iy + ix * ny of is^th stick among all sticks.
int *is2fftixy=nullptr, * d_is2fftixy = nullptr; // is2fftixy[is]: iy + ix * ny of is^th stick among sticks on current proc.
int *fftixy2ip=nullptr; // fftixy2ip[iy + ix * fftny]: ip of proc which contains stick on (ix, iy). if no stick: -1
Expand Down Expand Up @@ -352,7 +353,10 @@ class PW_Basis
void recip_to_real(TK* in,
TR* out,
const bool add = false,
const typename GetTypeReal<TK>::type factor = 1.0) const;
const typename GetTypeReal<TK>::type factor = 1.0) const
{
this->recip2real_gpu(in,out,add,factor);
};

// template <typename FPTYPE,
// typename Device,
Expand Down Expand Up @@ -383,9 +387,7 @@ class PW_Basis
* values in the output array.
* @param factor Optional scaling factor, default value 1.0, applied to the output values.
*/
template <typename TK,
typename TR,
typename Device,
template <typename TR,typename TK,typename Device,
typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
&& (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
&& std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
Expand All @@ -397,14 +399,17 @@ class PW_Basis
this->real2recip(in, out, add, factor);
}

template <typename TK,typename TR, typename Device,
template <typename TR, typename TK, typename Device,
typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
&& (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
&& !std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
&& std::is_same<Device, base_device::DEVICE_GPU>::value ,int>::type = 0>
void real_to_recip(TR* in,
TK* out,
const bool add = false,
const typename GetTypeReal<TK>::type factor = 1.0) const;
const typename GetTypeReal<TK>::type factor = 1.0) const
{
this->real2recip_gpu(in,out,add,factor);
};

protected:
//gather planes and scatter sticks of all processors
Expand All @@ -431,6 +436,7 @@ class PW_Basis

std::string device = "cpu"; ///< cpu or gpu
std::string precision = "double"; ///< single, double, mixing
bool mpi_flag_ = false; ///< true: the MPI calls guarded by this flag are executed; false (default): they are skipped
bool double_data_ = true; ///< if has double data
bool float_data_ = false; ///< if has float data
};
Expand Down
7 changes: 6 additions & 1 deletion source/module_basis/module_pw/pw_basis_big.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,12 @@ class PW_Basis_Big : public PW_Basis_Sup
ibox[0] = 2*n1+1;
ibox[1] = 2*n2+1;
ibox[2] = 2*n3+1;
if (mpi_flag_)
{
#ifdef __MPI
MPI_Allreduce(MPI_IN_PLACE, ibox, 3, MPI_INT, MPI_MAX , this->pool_world);
#endif

}

// Find the minimal FFT box size that factors into the primes (2,3,5,7).
for (int i = 0; i < 3; i++)
Expand Down Expand Up @@ -350,9 +352,12 @@ class PW_Basis_Big : public PW_Basis_Sup
}
}
}
if (mpi_flag_)
{
#ifdef __MPI
MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world);
#endif
}
this->gridecut_lat -= 1e-6;

delete[] ibox;
Expand Down
3 changes: 1 addition & 2 deletions source/module_basis/module_pw/pw_basis_k_big.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class PW_Basis_K_Big: public PW_Basis_K
for(int ip = 0 ; ip < this->poolnproc ; ++ip)
{
this->numz[ip] = npbz*this->bz;
if(ip < modbz) { this->numz[ip]+=this->bz;
}
if(ip < modbz) { this->numz[ip]+=this->bz;}
if(ip < this->poolnproc - 1) this->startz[ip+1] = this->startz[ip] + numz[ip];
if(ip == this->poolrank)
{
Expand Down
40 changes: 25 additions & 15 deletions source/module_basis/module_pw/pw_basis_sup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,25 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
// calculate this->nstot and this->npwtot, liy, riy
this->count_pw_st(st_length2D, st_bottom2D);
}
if (mpi_flag_)
{
#ifdef __MPI
MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);

MPI_Bcast(&this->npwtot, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&this->nstot, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&liy, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&riy, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&lix, 1, MPI_INT, 0, this->pool_world);
MPI_Bcast(&rix, 1, MPI_INT, 0, this->pool_world);
#endif
}
delete[] this->istot2ixy;
this->istot2ixy = new int[this->nstot];

if (poolrank == 0)
{
if (mpi_flag_)
{
#ifdef __MPI
// Parallel line
// (2) Collect the x, y indices, and length of the sticks.
Expand All @@ -147,7 +153,8 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
// We do not need startnsz_per after it.
delete[] this->startnsz_per;
this->startnsz_per = nullptr;
#else
#endif
}else{
// Serial line
// get nst_per, npw_per, fftixy2ip, and istot2ixy
this->nst_per[0] = this->nstot;
Expand All @@ -162,17 +169,20 @@ void PW_Basis_Sup::distribution_method3(const ModulePW::PW_Basis* pw_rho)
st_move++;
}
}
#endif
}

}
if (mpi_flag_)
{
#ifdef __MPI
MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);

MPI_Bcast(st_length2D, this->fftnxy, MPI_INT, 0, this->pool_world);
MPI_Bcast(st_bottom2D, this->fftnxy, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->fftixy2ip, this->fftnxy, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->istot2ixy, this->nstot, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->nst_per, this->poolnproc, MPI_INT, 0, this->pool_world);
MPI_Bcast(this->npw_per, this->poolnproc, MPI_INT, 0, this->pool_world);
#endif
}
this->npw = this->npw_per[this->poolrank];
this->nst = this->nst_per[this->poolrank];
this->nstnz = this->nst * this->nz;
Expand Down
31 changes: 25 additions & 6 deletions source/module_basis/module_pw/pw_distributeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,23 +189,42 @@ void PW_Basis::get_ig2isz_is2fftixy(
for (int iz = zstart; iz < zstart + st_length2D[ixy]; ++iz)
{
int z = iz;
if (z < 0) { z += this->nz;
}
if (z < 0)
{
z += this->nz;
}
this->ig2isz[pw_filled] = st_move * this->nz + z;
pw_filled++;
}
this->is2fftixy[st_move] = ixy;
st_move++;
if(xprime && ixy/fftny == 0) { ng_xeq0 = pw_filled;
}
if (xprime && ixy / fftny == 0)
{
ng_xeq0 = pw_filled;
}
}
if (st_move == this->nst && pw_filled == this->npw) { break;
}
if (st_move == this->nst && pw_filled == this->npw)
{
break;
}
}
// Build a host-side table that maps every local plane wave ig to the
// flattened index iz + iy * nz + ix * ny * nz. The table is only consumed by
// the upload into ig2ixyz_gpu below (when device == "gpu").
std::vector<int> ig2ixyz(this->npw);
for (int igl = 0; igl < this->npw; ++igl)
{
// ig2isz was stored above as st_move * nz + z, so split it back into the
// z index (remainder) and the local stick index (quotient).
int isz = this->ig2isz[igl];
int iz = isz % this->nz;
int is = isz / this->nz;
// is2fftixy[is] holds the combined (ix, iy) index of stick `is`.
// NOTE(review): the loop above derives the x index as `ixy / fftny`, whereas
// this decode uses this->ny -- confirm fftny == ny whenever this table is used.
int ixy = this->is2fftixy[is];
int iy = ixy % this->ny;
int ix = ixy / this->ny;
ig2ixyz[igl] = iz + iy * nz + ix * ny * nz;
}
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu") {
resmem_int_op()(d_is2fftixy, this->nst);
syncmem_int_h2d_op()(this->d_is2fftixy, this->is2fftixy, this->nst);
resmem_int_op()(ig2ixyz_gpu,this->npw);
syncmem_int_h2d_op()(ig2ixyz_gpu, ig2ixyz.data(), this->npw);
}
#endif
return;
Expand Down
Loading
Loading