Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c73bd57
add the basic func of the file
A-006 Nov 5, 2024
e838161
modify the Makefile
A-006 Nov 5, 2024
75e42ba
delete file
A-006 Nov 5, 2024
8d73d7e
modify the position of the new fft
A-006 Nov 5, 2024
5304827
modify the Makefile
A-006 Nov 5, 2024
4049c76
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 5, 2024
5cfd6bc
add the cpu float in the fft floder
A-006 Nov 5, 2024
9b8fb19
change the test file
A-006 Nov 5, 2024
d7f50df
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 5, 2024
9ad3c19
add the func in test
A-006 Nov 5, 2024
dfaad66
add the float fft
A-006 Nov 5, 2024
b40629d
change ft into ft1
A-006 Nov 5, 2024
965d627
add the file of the float_define and the device set
A-006 Nov 5, 2024
39e0d0c
delete the memory allocate in the ft
A-006 Nov 6, 2024
b183967
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 6, 2024
2a45b32
add the Smart Pointer and the logic gate
A-006 Nov 7, 2024
28419a6
modify the position of the FFT
A-006 Nov 7, 2024
d29b355
change fft_bundle name
A-006 Nov 7, 2024
6cc4bac
save version of the pw_test and single version
A-006 Nov 7, 2024
2d0b5f3
fix complie bug and change the fftwf logic
A-006 Nov 8, 2024
0610758
Merge branch 'develop' into fft1
A-006 Nov 8, 2024
da26acc
add comments for the fft class
A-006 Nov 8, 2024
9b200a2
modify the fft name and add comments
A-006 Nov 8, 2024
f07a97d
modify the Makefile
A-006 Nov 8, 2024
2623735
Merge branch 'develop' into fft1
A-006 Nov 9, 2024
61bb766
Merge branch 'develop' into fft1
mohanchen Nov 10, 2024
9c1a22d
update the file
A-006 Nov 11, 2024
9029453
update the format
A-006 Nov 11, 2024
479b178
update the shared_ptr
A-006 Nov 11, 2024
76f0d85
Merge branch 'develop' into fft1
A-006 Nov 11, 2024
00dfee5
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/scf/pw_Si2/INPUT
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ scf_thr 1e-7
scf_nmax 100
device cpu
ks_solver dav_subspace
precision double
precision single
5 changes: 4 additions & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ VPATH=./src_global:\
./module_base/module_mixing:\
./module_md:\
./module_basis/module_pw:\
./module_basis/module_pw/module_fft:\
./module_esolver:\
./module_hsolver:\
./module_hsolver/kernels:\
Expand Down Expand Up @@ -168,7 +169,6 @@ OBJS_BASE=abfs-vector3_order.o\
memory_op.o\
device.o\


OBJS_CELL=atom_pseudo.o\
atom_spec.o\
pseudo.o\
Expand Down Expand Up @@ -414,6 +414,9 @@ OBJS_PSI_INITIALIZER=psi_initializer.o\
psi_initializer_nao_random.o\

OBJS_PW=fft.o\
fft_bundle.o\
fft_base.o\
fft_cpu.o\
pw_basis.o\
pw_basis_k.o\
pw_basis_sup.o\
Expand Down
1 change: 0 additions & 1 deletion source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ list (APPEND LIBM_SRC
libm/sincos.cpp
)
endif()

add_library(
base
OBJECT
Expand Down
9 changes: 9 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
if (ENABLE_FLOAT_FFTW)
list (APPEND FFT_SRC
module_fft/fft_cpu_float.cpp
)
endif()
list(APPEND objects
fft.cpp
pw_basis.cpp
Expand All @@ -10,6 +15,10 @@ list(APPEND objects
pw_init.cpp
pw_transform.cpp
pw_transform_k.cpp
module_fft/fft_base.cpp
module_fft/fft_bundle.cpp
module_fft/fft_cpu.cpp
${FFT_SRC}
)

add_library(
Expand Down
55 changes: 28 additions & 27 deletions source/module_basis/module_pw/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
this->fftny = this->ny = ny_in;
if (this->gamma_only)
{
if (xprime)
if (xprime) {
this->fftnx = int(nx / 2) + 1;
else
} else {
this->fftny = int(ny / 2) + 1;
}
}
this->nz = nz_in;
this->ns = ns_in;
Expand All @@ -92,10 +93,10 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
int maxgrids = (nsz > nrxx) ? nsz : nrxx;
if (!this->mpifft)
{
z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
d_rspace = (double*)z_auxg;
// z_auxg = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// z_auxr = (std::complex<double>*)fftw_malloc(sizeof(fftw_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid", 2 * sizeof(fftw_complex) * maxgrids);
// d_rspace = (double*)z_auxg;
// auxr_3d = static_cast<std::complex<double> *>(
// fftw_malloc(sizeof(fftw_complex) * (this->nx * this->ny * this->nz)));
#if defined(__CUDA) || defined(__ROCM)
Expand All @@ -105,15 +106,15 @@ void FFT::initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int
resmem_zd_op()(gpu_ctx, this->z_auxr_3d, this->nx * this->ny * this->nz);
}
#endif // defined(__CUDA) || defined(__ROCM)
#if defined(__ENABLE_FLOAT_FFTW)
if (this->precision == "single")
{
c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
s_rspace = (float*)c_auxg;
}
#endif // defined(__ENABLE_FLOAT_FFTW)
// #if defined(__ENABLE_FLOAT_FFTW)
// if (this->precision == "single")
// {
// c_auxg = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// c_auxr = (std::complex<float>*)fftw_malloc(sizeof(fftwf_complex) * maxgrids);
// ModuleBase::Memory::record("FFT::grid_s", 2 * sizeof(fftwf_complex) * maxgrids);
// s_rspace = (float*)c_auxg;
// }
// #endif // defined(__ENABLE_FLOAT_FFTW)
}
else
{
Expand Down Expand Up @@ -353,62 +354,62 @@ void FFT::cleanFFT()
if (planzfor)
{
fftw_destroy_plan(planzfor);
planzfor = NULL;
planzfor = nullptr;
}
if (planzbac)
{
fftw_destroy_plan(planzbac);
planzbac = NULL;
planzbac = nullptr;
}
if (planxfor1)
{
fftw_destroy_plan(planxfor1);
planxfor1 = NULL;
planxfor1 = nullptr;
}
if (planxbac1)
{
fftw_destroy_plan(planxbac1);
planxbac1 = NULL;
planxbac1 = nullptr;
}
if (planxfor2)
{
fftw_destroy_plan(planxfor2);
planxfor2 = NULL;
planxfor2 = nullptr;
}
if (planxbac2)
{
fftw_destroy_plan(planxbac2);
planxbac2 = NULL;
planxbac2 = nullptr;
}
if (planyfor)
{
fftw_destroy_plan(planyfor);
planyfor = NULL;
planyfor = nullptr;
}
if (planybac)
{
fftw_destroy_plan(planybac);
planybac = NULL;
planybac = nullptr;
}
if (planxr2c)
{
fftw_destroy_plan(planxr2c);
planxr2c = NULL;
planxr2c = nullptr;
}
if (planxc2r)
{
fftw_destroy_plan(planxc2r);
planxc2r = NULL;
planxc2r = nullptr;
}
if (planyr2c)
{
fftw_destroy_plan(planyr2c);
planyr2c = NULL;
planyr2c = nullptr;
}
if (planyc2r)
{
fftw_destroy_plan(planyc2r);
planyc2r = NULL;
planyc2r = nullptr;
}

// fftw_destroy_plan(this->plan3dforward);
Expand Down
17 changes: 17 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "fft_base.h"
namespace ModulePW
{
template <typename FPTYPE>
FFT_BASE<FPTYPE>::FFT_BASE()
{
}
template <typename FPTYPE>
FFT_BASE<FPTYPE>::~FFT_BASE()
{
}

template FFT_BASE<float>::FFT_BASE();
template FFT_BASE<double>::FFT_BASE();
template FFT_BASE<float>::~FFT_BASE();
template FFT_BASE<double>::~FFT_BASE();
}
62 changes: 62 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include <complex>
#include <string>
#include "fftw3.h"
#ifndef FFT_BASE_H
#define FFT_BASE_H
namespace ModulePW
{
template <typename FPTYPE>
class FFT_BASE
{
public:

FFT_BASE();
virtual ~FFT_BASE();

// init parameters of fft
virtual __attribute__((weak))
void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in,
int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false);

//init fftw_plans
virtual void setupFFT()=0;

//destroy fftw_plans
virtual void cleanFFT()=0;
//clear fftw_data
virtual void clear()=0;

// access the real space data
virtual __attribute__((weak)) FPTYPE* get_rspace_data() const;

virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxr_data() const;

virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxg_data() const;

virtual __attribute__((weak)) std::complex<FPTYPE>* get_auxr_3d_data() const;

//forward fft in x-y direction
virtual __attribute__((weak)) void fftxyfor(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftxybac(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftzfor(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftzbac(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftxyr2c(FPTYPE* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fftxyc2r(std::complex<FPTYPE>* in, FPTYPE* out) const;

virtual __attribute__((weak)) void fft3D_forward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

virtual __attribute__((weak)) void fft3D_backward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;

protected:
int ny=0;
int nx=0;
int nz=0;

};
}
#endif // FFT_BASE_H
Loading
Loading