Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
7b855e6
add unit test
A-006 Apr 8, 2025
4c779e8
add intergrate test
A-006 Apr 8, 2025
f9e7710
fix process
A-006 Apr 8, 2025
a8d72af
modify jd
A-006 Apr 8, 2025
5faf27e
update bug
A-006 Apr 8, 2025
4773008
set fftw float
A-006 Apr 8, 2025
503bf58
add the float BPCG
A-006 Apr 8, 2025
4b5df98
add float test
A-006 Apr 8, 2025
16cb172
fix compile bug
A-006 Apr 8, 2025
f5c1fc1
fix error
A-006 Apr 9, 2025
677e5d6
fix the compile test
A-006 Apr 9, 2025
98decc4
Merge branch 'develop' into fft_float2
A-006 Apr 9, 2025
f632326
Merge branch 'develop' into fft_float2
A-006 Apr 10, 2025
300713c
add
A-006 Apr 10, 2025
1dbacf8
remove the test file
A-006 Apr 18, 2025
f565945
change the file
A-006 Apr 18, 2025
e1601ee
revert bug
A-006 Apr 18, 2025
f6fd16d
set the float type
A-006 Apr 18, 2025
bed7852
Merge branch 'develop' into fft_float2
A-006 Apr 18, 2025
80344ac
reset the FFT_MEASURE
A-006 Apr 18, 2025
c60bf81
update unittest
A-006 Apr 18, 2025
ed18346
change readme
A-006 Apr 22, 2025
1f66367
update threashold
A-006 Apr 22, 2025
4c63669
Merge branch 'develop' into fft_float2
A-006 Apr 22, 2025
7553e06
use the test file
A-006 Apr 22, 2025
385b010
fix unresonable comments
A-006 Apr 22, 2025
2e13c7f
update eslover before all runners
A-006 Apr 27, 2025
2bf18b9
Merge branch 'develop' into fft_float2
A-006 Apr 27, 2025
a224da7
fix compile bug
A-006 Apr 27, 2025
59b73f5
fix bug
A-006 Apr 27, 2025
f750e10
Merge branch 'develop' into fft_float2
mohanchen Apr 29, 2025
d193075
update README
A-006 May 6, 2025
a9b53a1
change chebyshev MPI part
A-006 May 6, 2025
5c156e5
Merge branch 'develop' into fft_float2
A-006 May 8, 2025
d5084f6
add new test
A-006 May 8, 2025
aa443f1
delete old test
A-006 May 8, 2025
2d2a550
remove old tests
A-006 May 9, 2025
b1f144e
add change
A-006 May 13, 2025
c60d13f
Merge branch 'develop' into fft_float2
A-006 May 13, 2025
df3c712
update tick
A-006 May 13, 2025
ca1b0d9
add back marco
A-006 May 13, 2025
5ca14cd
update change
A-006 May 14, 2025
be074ab
Merge branch 'develop' into fft_float2
A-006 May 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:

- name: Configure
run: |
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON -DENABLE_MLKEDF=ON -DENABLE_LIBXC=ON -DENABLE_LIBRI=ON -DENABLE_PAW=ON -DENABLE_GOOGLEBENCH=ON -DENABLE_RAPIDJSON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=1
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON -DENABLE_MLKEDF=ON -DENABLE_LIBXC=ON -DENABLE_LIBRI=ON -DENABLE_PAW=ON -DENABLE_GOOGLEBENCH=ON -DENABLE_RAPIDJSON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DENABLE_FLOAT_FFTW=ON

# Temporarily removed because no one maintains this now.
# And it will break the CI test workflow.
Expand Down
3 changes: 3 additions & 0 deletions source/module_base/test/math_chebyshev_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,8 @@ TEST_F(MathChebyshevTest, tracepolyA_float)

TEST_F(MathChebyshevTest, checkconverge_float)
{
#ifdef __MPI
#undef __MPI
const int norder = 100;
p_fchetest = new ModuleBase::Chebyshev<float>(norder);

Expand All @@ -648,5 +650,6 @@ TEST_F(MathChebyshevTest, checkconverge_float)

delete[] v;
delete p_fchetest;
#endif
}
#endif
12 changes: 12 additions & 0 deletions source/module_basis/module_pw/module_fft/fft_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "fft_cpu.h"
#include "fftw3.h"
#include "module_base/timer.h"
namespace ModulePW
{

Expand Down Expand Up @@ -347,18 +348,22 @@ void FFT_CPU<double>::fftxyfor(std::complex<double>* in, std::complex<double>* o
int npy = this->nplane * this->ny;
if (this->xprime)
{

fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out);
#pragma omp parallel for
for (int i = 0; i < this->lixy + 1; ++i)
{
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
}
#pragma omp parallel for
for (int i = rixy; i < this->nx; ++i)
{
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
}
}
else
{
#pragma omp parallel for
for (int i = 0; i < this->nx; ++i)
{
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
Expand All @@ -374,10 +379,12 @@ void FFT_CPU<double>::fftxybac(std::complex<double>* in,std::complex<double>* ou
int npy = this->nplane * this->ny;
if (this->xprime)
{
#pragma omp parallel for
for (int i = 0; i < this->lixy + 1; ++i)
{
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
}
#pragma omp parallel for
for (int i = rixy; i < this->nx; ++i)
{
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
Expand All @@ -388,6 +395,7 @@ void FFT_CPU<double>::fftxybac(std::complex<double>* in,std::complex<double>* ou
{
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out);
fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]);
#pragma omp parallel for
for (int i = 0; i < this->nx; ++i)
{
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
Expand All @@ -414,13 +422,15 @@ void FFT_CPU<double>::fftxyr2c(double* in, std::complex<double>* out) const
if (this->xprime)
{
fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out);
#pragma omp parallel for
for (int i = 0; i < this->lixy + 1; ++i)
{
fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]);
}
}
else
{
#pragma omp parallel for
for (int i = 0; i < this->nx; ++i)
{
fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]);
Expand All @@ -435,6 +445,7 @@ void FFT_CPU<double>::fftxyc2r(std::complex<double> *in,double *out) const
int npy = this->nplane * this->ny;
if (this->xprime)
{
#pragma omp parallel for
for (int i = 0; i < this->lixy + 1; ++i)
{
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]);
Expand All @@ -444,6 +455,7 @@ void FFT_CPU<double>::fftxyc2r(std::complex<double> *in,double *out) const
else
{
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in);
#pragma omp parallel for
for (int i = 0; i < this->nx; ++i)
{
fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]);
Expand Down
19 changes: 17 additions & 2 deletions source/module_basis/module_pw/pw_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,24 @@ PW_Basis::PW_Basis()

PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) {
classname="PW_Basis";
this->fft_bundle.setfft("cpu",this->precision);
std::string fft_precison;
if ((this->precision=="single") || (this->precision=="mixing"))
{
fft_precison = "mixing";
}
else if (this->precision=="double")
{
fft_precison = "double";
}
#if (not defined(__ENABLE_FLOAT_FFTW) and (defined(__CUDA) || defined(__RCOM)))
if (this->device == "gpu")
{
fft_precison = "double";
}
#endif
this->fft_bundle.setfft("cpu",fft_precison);
this->double_data_ = (this->precision == "double") || (this->precision == "mixing");
this->float_data_ = (this->precision == "single") || (this->precision == "mixing");
this->float_data_ = (this->precision == "single") || (this->precision == "mixing");
}

PW_Basis:: ~PW_Basis()
Expand Down
6 changes: 3 additions & 3 deletions source/module_basis/module_pw/pw_basis_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ void PW_Basis_K::setuptransform()
this->getstartgr();
this->setupIndGk();
this->fft_bundle.clear();
std::string fft_device = this->device;
#if defined(__DSP)
this->fft_bundle.setfft("dsp", this->precision);
#else
this->fft_bundle.setfft(this->device, this->precision);
fft_device = "dsp";
#endif
this->fft_bundle.setfft(fft_device, this->precision);
if (this->xprime)
{
this->fft_bundle.initfft(this->nx,
Expand Down
4 changes: 2 additions & 2 deletions source/module_basis/module_pw/pw_gatherscatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
template <typename T>
void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
{
//ModuleBase::timer::tick(this->classname, "gathers_scatterp");
ModuleBase::timer::tick(this->classname, "gathers_scatterp");

if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot,
{
Expand Down Expand Up @@ -185,7 +185,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
}

#endif
//ModuleBase::timer::tick(this->classname, "gathers_scatterp");
ModuleBase::timer::tick(this->classname, "gathers_scatterp");
return;
}

Expand Down
2 changes: 1 addition & 1 deletion source/module_basis/module_pw/pw_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
#endif
for (int i = 0; i < this->nst * this->nz; ++i)
{
fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<double>(0, 0);
fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<FPTYPE>(0, 0);
}

#ifdef _OPENMP
Expand Down
5 changes: 4 additions & 1 deletion source/module_basis/module_pw/pw_transform_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
{
ModuleBase::timer::tick(this->classname, "recip2real");
assert(this->gamma_only == false);
ModuleBase::timer::tick("fftxybac", "recip2real_init");
ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxg_data<FPTYPE>(), this->nst * this->nz);

const int startig = ik * this->npwk_max;
Expand All @@ -182,12 +183,13 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
{
auxg[this->igl2isz_k[igl + startig]] = in[igl];
}
ModuleBase::timer::tick("fftxybac", "recip2real_init");
this->fft_bundle.fftzbac(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());

this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>());

this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

ModuleBase::timer::tick("fftxybac", "recip2real_back");
auto* auxr = this->fft_bundle.get_auxr_data<FPTYPE>();
if (add)
{
Expand All @@ -209,6 +211,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
out[ir] = auxr[ir];
}
}
ModuleBase::timer::tick("fftxybac", "recip2real_back");
ModuleBase::timer::tick(this->classname, "recip2real");
}

Expand Down
3 changes: 1 addition & 2 deletions source/module_basis/module_pw/test/pw_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ class TestEnv : public testing::Environment

int main(int argc, char **argv)
{

int kpar;
kpar = 1;
#ifdef __ENABLE_FLOAT_FFTW
precision_flag = "single";
precision_flag = "mixing";
#else
precision_flag = "double";
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ TEST_F(PWBasisKTEST,Constructor)
EXPECT_EQ(basis_k2.precision,"double");
EXPECT_EQ(basis_k2.fft_bundle.precision,"double");
ModulePW::PW_Basis_K basis_k3(device_flag, precision_single);
EXPECT_EQ(basis_k3.fft_bundle.precision,"single");
EXPECT_EQ(basis_k3.precision,"single");
EXPECT_EQ(basis_k3.fft_bundle.precision,"mixing");
}

TEST_F(PWBasisKTEST,Initgrids1)
Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ namespace ModuleESolver
ESolver_FP::ESolver_FP()
{
std::string fft_device = PARAM.inp.device;

// LCAO basis doesn't support GPU acceleration on FFT currently
if(PARAM.inp.basis_type == "lcao")
{
fft_device = "cpu";
}

pw_rho = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
if (PARAM.globalv.double_grid)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ void Structure_Factor::setup_structure_factor(const UnitCell* Ucell, const Paral
// std::ofstream ofs( outstr.c_str() ) ;
bool usebspline;
if(nbspline > 0) { usebspline = true;
} else { usebspline = false;
}
} else { usebspline = false;}

if(usebspline)
{
Expand Down Expand Up @@ -147,6 +146,7 @@ void Structure_Factor::setup_structure_factor(const UnitCell* Ucell, const Paral
inat++;
}
}

if (device == "gpu") {
if (PARAM.globalv.has_float_data) {
resmem_cd_op()(this->c_eigts1, Ucell->nat * (2 * rho_basis->nx + 1));
Expand Down
28 changes: 28 additions & 0 deletions source/module_hamilt_pw/hamilt_pwdft/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ remove_definitions(-D__DEEPKS)
remove_definitions(-D__CUDA)
remove_definitions(-D__ROCM)
remove_definitions(-D__EXX)
remove_definitions(-DUSE_PAW)

AddTest(
TARGET pwdft_soc
Expand All @@ -26,4 +27,31 @@ AddTest(
TARGET radial_proj_test
LIBS parameter base device ${math_libs}
SOURCES radial_proj_test.cpp ../radial_proj.cpp
)

AddTest(
TARGET structure_factor_test
LIBS parameter ${math_libs} base device planewave
SOURCES structure_factor_test.cpp ../structure_factor.cpp ../parallel_grid.cpp
../../../module_cell/unitcell.cpp
../../../module_io/output.cpp
../../../module_cell/update_cell.cpp
../../../module_cell/bcast_cell.cpp
../../../module_cell/print_cell.cpp
../../../module_cell/atom_spec.cpp
../../../module_cell/atom_pseudo.cpp
../../../module_cell/pseudo.cpp
../../../module_cell/read_stru.cpp
../../../module_cell/read_atom_species.cpp
../../../module_cell/read_atoms.cpp
../../../module_cell/read_pp.cpp
../../../module_cell/read_pp_complete.cpp
../../../module_cell/read_pp_upf100.cpp
../../../module_cell/read_pp_upf201.cpp
../../../module_cell/read_pp_vwr.cpp
../../../module_cell/read_pp_blps.cpp
../../../module_elecstate/read_pseudo.cpp
../../../module_elecstate/cal_wfc.cpp
../../../module_elecstate/cal_nelec_nband.cpp
../../../module_elecstate/read_orb.cpp
)
Loading
Loading