Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:

- name: Configure
run: |
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON -DENABLE_MLKEDF=ON -DENABLE_LIBXC=ON -DENABLE_LIBRI=ON -DENABLE_PAW=ON -DENABLE_GOOGLEBENCH=ON -DENABLE_RAPIDJSON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=1
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON -DENABLE_MLKEDF=ON -DENABLE_LIBXC=ON -DENABLE_LIBRI=ON -DENABLE_PAW=ON -DENABLE_GOOGLEBENCH=ON -DENABLE_RAPIDJSON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DENABLE_FLOAT_FFTW=ON

- uses: pre-commit/[email protected]
with:
Expand Down
4 changes: 4 additions & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,8 @@ if(BUILD_TESTING)
add_subdirectory(test_serial)
add_subdirectory(kernels/test)
endif()
add_subdirectory(module_fft/test)
if (ENABLE_FLOAT_FFTW)
add_subdirectory(module_fft/test_fftwf)
endif()
endif()
76 changes: 39 additions & 37 deletions source/module_basis/module_pw/module_fft/fft_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@

#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include "module_base/tool_quit.h"
#if defined(__CUDA)
#include "fft_cuda.h"
#endif
#if defined(__ROCM)
#include "fft_rocm.h"
#endif

/// Allocates a `T` on the heap, perfectly forwarding `ctor_args` to its
/// constructor, and returns the object wrapped in a `std::unique_ptr<T>`.
template <typename T, typename... CtorArgs>
std::unique_ptr<T> make_unique(CtorArgs&&... ctor_args)
{
    std::unique_ptr<T> owner(new T(std::forward<CtorArgs>(ctor_args)...));
    return owner;
}
namespace ModulePW
{
FFT_Bundle::~FFT_Bundle()
Expand Down Expand Up @@ -43,27 +39,48 @@ void FFT_Bundle::initfft(int nx_in,
assert(this->device=="cpu" || this->device=="gpu");
assert(this->precision=="single" || this->precision=="double" || this->precision=="mixing");

if (this->precision=="single")
float_flag = (this->precision=="single" || this->precision=="mixing")? true:false;
double_flag = true;

if (this->device=="gpu")
{
#if not defined (__ENABLE_FLOAT_FFTW)
if (this->device == "cpu"){
float_define = false;
}
#endif
#if defined(__CUDA) || defined (__ROCM)
if (this->device == "gpu"){
float_flag = float_define;
}
#if defined(__ROCM)
if (float_flag)
{
fft_float = make_unique<FFT_ROCM<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
}
if (double_flag)
{
fft_double = make_unique<FFT_ROCM<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
}
#elif defined(__CUDA)
if (float_flag)
{
fft_float = make_unique<FFT_CUDA<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
}
if (double_flag)
{
fft_double = make_unique<FFT_CUDA<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
}
#else
std::cout<<"wihout the CUDA OR DCU,but set the device as gpu, use cpu instead\n";
this->device="cpu";
#endif
float_flag = float_define;
double_flag = true;
}
if (this->precision=="double")
{
double_flag = true;
}

#if not defined (__ENABLE_FLOAT_FFTW)
if (this->device == "cpu" && float_flag)
{
float_define = false;
ModuleBase::WARNING_QUIT("initfft", "please complie abacus with fftw3_FLOAT");
}
#endif

if (device=="cpu")
if (this->device=="cpu")
{
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
Expand Down Expand Up @@ -94,21 +111,6 @@ void FFT_Bundle::initfft(int nx_in,
xprime_in);
}
}
if (device=="gpu")
{
#if defined(__ROCM)
fft_float = make_unique<FFT_ROCM<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = make_unique<FFT_ROCM<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#elif defined(__CUDA)
fft_float = make_unique<FFT_CUDA<float>>();
fft_float->initfft(nx_in,ny_in,nz_in);
fft_double = make_unique<FFT_CUDA<double>>();
fft_double->initfft(nx_in,ny_in,nz_in);
#endif
}

}

void FFT_Bundle::setupFFT()
Expand Down
6 changes: 6 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_bundle.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
#include "fft_cpu.h"
namespace ModulePW
{
/// Constructs an `Object` with `new`, forwarding every argument unchanged to
/// the constructor, and hands ownership to the returned `std::unique_ptr`.
template <typename Object, typename... Params>
std::unique_ptr<Object> make_unique(Params&&... params)
{
    Object* const raw = new Object(std::forward<Params>(params)...);
    return std::unique_ptr<Object>(raw);
}

class FFT_Bundle
{
public:
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/module_fft/fft_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ void FFT_CPU<double>::setupFFT()
template <>
void FFT_CPU<double>::clearfft(fftw_plan& plan)
{
if (plan)
if (plan!=nullptr)
{
fftw_destroy_plan(plan);
plan = nullptr;
Expand Down
50 changes: 25 additions & 25 deletions source/module_basis/module_pw/module_fft/fft_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,31 +102,31 @@ class FFT_CPU : public FFT_BASE<FPTYPE>
void clearfft(fftw_plan& plan);
void clearfft(fftwf_plan& plan);

fftw_plan planzfor = NULL;
fftw_plan planzbac = NULL;
fftw_plan planxfor1 = NULL;
fftw_plan planxbac1 = NULL;
fftw_plan planxfor2 = NULL;
fftw_plan planxbac2 = NULL;
fftw_plan planyfor = NULL;
fftw_plan planybac = NULL;
fftw_plan planxr2c = NULL;
fftw_plan planxc2r = NULL;
fftw_plan planyr2c = NULL;
fftw_plan planyc2r = NULL;

fftwf_plan planfzfor = NULL;
fftwf_plan planfzbac = NULL;
fftwf_plan planfxfor1= NULL;
fftwf_plan planfxbac1= NULL;
fftwf_plan planfxfor2= NULL;
fftwf_plan planfxbac2= NULL;
fftwf_plan planfyfor = NULL;
fftwf_plan planfybac = NULL;
fftwf_plan planfxr2c = NULL;
fftwf_plan planfxc2r = NULL;
fftwf_plan planfyr2c = NULL;
fftwf_plan planfyc2r = NULL;
fftw_plan planzfor = nullptr;
fftw_plan planzbac = nullptr;
fftw_plan planxfor1 = nullptr;
fftw_plan planxbac1 = nullptr;
fftw_plan planxfor2 = nullptr;
fftw_plan planxbac2 = nullptr;
fftw_plan planyfor = nullptr;
fftw_plan planybac = nullptr;
fftw_plan planxr2c = nullptr;
fftw_plan planxc2r = nullptr;
fftw_plan planyr2c = nullptr;
fftw_plan planyc2r = nullptr;

fftwf_plan planfzfor = nullptr;
fftwf_plan planfzbac = nullptr;
fftwf_plan planfxfor1= nullptr;
fftwf_plan planfxbac1= nullptr;
fftwf_plan planfxfor2= nullptr;
fftwf_plan planfxbac2= nullptr;
fftwf_plan planfyfor = nullptr;
fftwf_plan planfybac = nullptr;
fftwf_plan planfxr2c = nullptr;
fftwf_plan planfxc2r = nullptr;
fftwf_plan planfyr2c = nullptr;
fftwf_plan planfyc2r = nullptr;

std::complex<float>*c_auxg = nullptr;
std::complex<float>*c_auxr = nullptr; // fft space
Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/module_fft/fft_cpu_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void FFT_CPU<float>::setupFFT()
template <>
void FFT_CPU<float>::clearfft(fftwf_plan& plan)
{
if (plan)
if (plan!=nullptr)
{
fftwf_destroy_plan(plan);
plan = nullptr;
Expand Down
16 changes: 16 additions & 0 deletions source/module_basis/module_pw/module_fft/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# These tests only compile the CPU FFT sources (fft_bundle.cpp, fft_cpu.cpp),
# so strip the GPU backend definitions for this directory. The sources guard
# on __CUDA and __ROCM; the ROCm macro was previously misspelled here as
# __RCOM, which left __ROCM defined for these targets.
remove_definitions(-D__CUDA)
remove_definitions(-D__ROCM)

# FFT_Bundle behaviour when only the double-precision FFTW backend is built.
AddTest(
  TARGET fft_bundle_without_fftwf_test
  LIBS parameter ${math_libs} base device FFTW3::FFTW3
  SOURCES fft_bundle_without_fftwf_test.cpp
  ../fft_bundle.cpp ../fft_cpu.cpp
)

# CPU FFT test.
AddTest(
  TARGET fft_test_cpu
  LIBS parameter ${math_libs} base device FFTW3::FFTW3
  SOURCES fft_test_cpu.cpp
  ../fft_bundle.cpp ../fft_cpu.cpp
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "malloc.h"
#include "fstream"
#define private public
#include "../fft_base.h"
#include "../fft_bundle.h"
#include "../fft_cpu.h"
#include "module_parameter/parameter.h"
#undef private
namespace ModulePW
{
// Test fixture: every test case gets its own FFT_Bundle and a string used to
// hold text captured from standard output.
class FftBundleTest : public ::testing::Test
{
protected:
FFT_Bundle fft_bundle;
// Receives the stdout text captured around the death-test calls.
std::string output;
// Nothing to prepare before a test.
void SetUp() override
{
}
// Nothing to clean up after a test.
void TearDown() override
{
}
};

// Checks that setfft() leaves the requested device and precision strings in
// the bundle's `device` and `precision` members.
// NOTE(review): these members are read directly; presumably that relies on
// the `#define private public` placed before the header includes -- confirm.
TEST_F(FftBundleTest,setfft)
{
fft_bundle.setfft("cpu","single");
EXPECT_EQ(fft_bundle.device,"cpu");
EXPECT_EQ(fft_bundle.precision,"single");

fft_bundle.setfft("gpu","double");
EXPECT_EQ(fft_bundle.device,"gpu");
EXPECT_EQ(fft_bundle.precision,"double");
}

// Exercises initfft() for each device/precision pairing used below.
// NOTE(review): the expectations look written for a build without
// single-precision FFTW and without a GPU backend -- confirm against the
// CMake target that compiles this file.
TEST_F(FftBundleTest,initfft)
{
// cpu + single: initfft() is expected to end the process with exit code 1.
// The child's stdout is captured so the printed message can be inspected.
fft_bundle.setfft("cpu","single");
testing::internal::CaptureStdout();
EXPECT_EXIT(fft_bundle.initfft(16,16,16,7,8,256,16,1,false,true),
::testing::ExitedWithCode(1),"");
output = testing::internal::GetCapturedStdout();
// "complie" is spelled this way on purpose: it has to match the text the
// program prints (presumably the WARNING_QUIT message in initfft -- verify).
EXPECT_THAT(output,testing::HasSubstr("complie"));

// cpu + double: initfft() returns normally and only the double flag is set.
fft_bundle.setfft("cpu","double");
fft_bundle.initfft(16,16,16,7,8,256,16,1,false,true);
EXPECT_EQ(fft_bundle.float_flag,false);
EXPECT_EQ(fft_bundle.double_flag,true);

// gpu + double: initfft() also returns normally with the same flag values.
fft_bundle.setfft("gpu","double");
fft_bundle.initfft(16,16,16,7,8,256,16,1,false,true);
EXPECT_EQ(fft_bundle.float_flag,false);
EXPECT_EQ(fft_bundle.double_flag,true);

// gpu + single: expected to exit with code 1 and print the same message as
// the cpu + single case.
fft_bundle.setfft("gpu","single");
testing::internal::CaptureStdout();
EXPECT_EXIT(fft_bundle.initfft(16,16,16,7,8,256,16,1,false,true),
::testing::ExitedWithCode(1),"");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output,testing::HasSubstr("complie"));

}
}
Loading
Loading