Skip to content

Commit 15bd8c8

Browse files
A-006mohanchen
andauthored
Resolve the segmentation fault occurring in the pw float implementation (#6130)
* add unit test * add intergrate test * fix process * modify jd * update bug * set fftw float * add the float BPCG * add float test * fix compile bug * fix error * fix the compile test * add * remove the test file * change the file * revert bug * set the float type * reset the FFT_MEASURE * update unittest * change readme * update threashold * use the test file * fix unresonable comments * update eslover before all runners * fix compile bug * fix bug * update README * change chebyshev MPI part * add new test * delete old test * remove old tests * add change * update tick * add back marco * update change --------- Co-authored-by: Mohan Chen <[email protected]>
1 parent 281d2a2 commit 15bd8c8

File tree

39 files changed

+610
-63
lines changed

39 files changed

+610
-63
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
3232
- name: Configure
3333
run: |
34-
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON -DENABLE_MLKEDF=ON -DENABLE_LIBXC=ON -DENABLE_LIBRI=ON -DENABLE_PAW=ON -DENABLE_GOOGLEBENCH=ON -DENABLE_RAPIDJSON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=1
34+
cmake -B build -DBUILD_TESTING=ON -DENABLE_DEEPKS=ON -DENABLE_MLKEDF=ON -DENABLE_LIBXC=ON -DENABLE_LIBRI=ON -DENABLE_PAW=ON -DENABLE_GOOGLEBENCH=ON -DENABLE_RAPIDJSON=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DENABLE_FLOAT_FFTW=ON
3535
3636
# Temporarily removed because no one maintains this now.
3737
# And it will break the CI test workflow.

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ __pycache__
2323
abacus.json
2424
*.npy
2525
toolchain/install/
26-
toolchain/abacus_env.sh
26+
toolchain/abacus_env.sh

source/module_base/test/math_chebyshev_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
* - calfinalvec_real
1515
* - calfinalvec_complex
1616
* - tracepolyA
17-
* - checkconverge
18-
*
19-
*
2017
*/
2118
class toolfunc
2219
{
@@ -625,6 +622,8 @@ TEST_F(MathChebyshevTest, tracepolyA_float)
625622

626623
TEST_F(MathChebyshevTest, checkconverge_float)
627624
{
625+
#ifdef __MPI
626+
#undef __MPI
628627
const int norder = 100;
629628
p_fchetest = new ModuleBase::Chebyshev<float>(norder);
630629

@@ -648,5 +647,6 @@ TEST_F(MathChebyshevTest, checkconverge_float)
648647

649648
delete[] v;
650649
delete p_fchetest;
650+
#endif
651651
}
652652
#endif

source/module_base/test_parallel/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ AddTest(
4040
SOURCES test_para_gemm.cpp
4141
)
4242

43+
AddTest(
44+
TARGET base_math_chebyshev_mpi
45+
LIBS MPI::MPI_CXX parameter ${math_libs} base device container
46+
SOURCES math_chebyshev_mpi_test.cpp
47+
)
48+
4349
add_test(NAME base_para_gemm_parallel
4450
COMMAND mpirun -np 4 ./base_para_gemm
4551
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#include "../math_chebyshev.h"
2+
#include "mpi.h"
3+
#include "module_base/parallel_comm.h"
4+
#include "gmock/gmock.h"
5+
#include "gtest/gtest.h"
6+
/************************************************
7+
* unit test of class Chebyshev MPI part
8+
***********************************************/
9+
10+
/**
11+
* - Tested Functions:
12+
* - checkconverge
13+
*/
14+
class toolfunc
15+
{
16+
public:
17+
double x7(double x)
18+
{
19+
return pow(x, 7);
20+
}
21+
double x6(double x)
22+
{
23+
return pow(x, 6);
24+
}
25+
double expr(double x)
26+
{
27+
return exp(x);
28+
}
29+
std::complex<double> expi(std::complex<double> x)
30+
{
31+
const std::complex<double> j(0.0, 1.0);
32+
return exp(j * x);
33+
}
34+
std::complex<double> expi2(std::complex<double> x)
35+
{
36+
const std::complex<double> j(0.0, 1.0);
37+
const double PI = 3.14159265358979323846;
38+
return exp(j * PI / 2.0 * x);
39+
}
40+
// Pauli matrix: [0,-i;i,0]
41+
int LDA = 2;
42+
double factor = 1;
43+
void sigma_y(std::complex<double>* spin_in, std::complex<double>* spin_out, const int m = 1)
44+
{
45+
const std::complex<double> j(0.0, 1.0);
46+
if (this->LDA < 2) {
47+
this->LDA = 2;
48+
}
49+
for (int i = 0; i < m; ++i)
50+
{
51+
spin_out[LDA * i] = -factor * j * spin_in[LDA * i + 1];
52+
spin_out[LDA * i + 1] = factor * j * spin_in[LDA * i];
53+
}
54+
}
55+
#ifdef __ENABLE_FLOAT_FFTW
56+
float x7(float x)
57+
{
58+
return pow(x, 7);
59+
}
60+
float x6(float x)
61+
{
62+
return pow(x, 6);
63+
}
64+
float expr(float x)
65+
{
66+
return exp(x);
67+
}
68+
std::complex<float> expi(std::complex<float> x)
69+
{
70+
const std::complex<float> j(0.0, 1.0);
71+
return exp(j * x);
72+
}
73+
std::complex<float> expi2(std::complex<float> x)
74+
{
75+
const std::complex<float> j(0.0, 1.0);
76+
const float PI = 3.14159265358979323846;
77+
return exp(j * PI / 2.0f * x);
78+
}
79+
// Pauli matrix: [0,-i;i,0]
80+
void sigma_y(std::complex<float>* spin_in, std::complex<float>* spin_out, const int m = 1)
81+
{
82+
const std::complex<float> j(0.0, 1.0);
83+
if (this->LDA < 2)
84+
this->LDA = 2;
85+
for (int i = 0; i < m; ++i)
86+
{
87+
spin_out[LDA * i] = -j * spin_in[LDA * i + 1];
88+
spin_out[LDA * i + 1] = j * spin_in[LDA * i];
89+
}
90+
}
91+
#endif
92+
};
93+
class MathChebyshevTest : public testing::Test
94+
{
95+
protected:
96+
ModuleBase::Chebyshev<double>* p_chetest;
97+
ModuleBase::Chebyshev<float>* p_fchetest;
98+
toolfunc fun;
99+
int dsize = 0;
100+
int my_rank = 0;
101+
void SetUp() override
102+
{
103+
int world_rank;
104+
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
105+
int world_size;
106+
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
107+
108+
int color = (world_rank < world_size / 2) ? 0 : 1;
109+
int key = world_rank;
110+
111+
MPI_Comm_split(MPI_COMM_WORLD, color, key, &POOL_WORLD);
112+
113+
int pool_rank, pool_size;
114+
MPI_Comm_rank(POOL_WORLD, &pool_rank);
115+
MPI_Comm_size(POOL_WORLD, &pool_size);
116+
}
117+
void TearDown() override
118+
{
119+
}
120+
};
121+
122+
TEST_F(MathChebyshevTest, checkconverge)
123+
{
124+
const int norder = 100;
125+
p_chetest = new ModuleBase::Chebyshev<double>(norder);
126+
auto fun_sigma_y
127+
= [&](std::complex<double>* in, std::complex<double>* out, const int m = 1) { fun.sigma_y(in, out, m); };
128+
129+
std::complex<double>* v = new std::complex<double>[4];
130+
v[0] = 1.0;
131+
v[1] = 0.0;
132+
v[2] = 0.0;
133+
v[3] = 1.0; //[1 0; 0 1]
134+
double tmin = -1.1;
135+
double tmax = 1.1;
136+
bool converge;
137+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
138+
EXPECT_TRUE(converge);
139+
converge = p_chetest->checkconverge(fun_sigma_y, v + 2, 2, 2, tmax, tmin, 0.2);
140+
EXPECT_TRUE(converge);
141+
EXPECT_NEAR(tmin, -1.1, 1e-8);
142+
EXPECT_NEAR(tmax, 1.1, 1e-8);
143+
144+
tmax = -1.1;
145+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 2.2);
146+
EXPECT_TRUE(converge);
147+
EXPECT_NEAR(tmin, -1.1, 1e-8);
148+
EXPECT_NEAR(tmax, 1.1, 1e-8);
149+
150+
// not converge
151+
v[0] = std::complex<double>(0, 1), v[1] = 1;
152+
fun.factor = 1.5;
153+
tmin = -1.1, tmax = 1.1;
154+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
155+
EXPECT_FALSE(converge);
156+
157+
fun.factor = -1.5;
158+
tmin = -1.1, tmax = 1.1;
159+
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
160+
EXPECT_FALSE(converge);
161+
fun.factor = 1;
162+
163+
delete[] v;
164+
delete p_chetest;
165+
}
166+
167+
#ifdef __ENABLE_FLOAT_FFTW
168+
TEST_F(MathChebyshevTest, checkconverge_float)
169+
{
170+
const int norder = 100;
171+
p_fchetest = new ModuleBase::Chebyshev<float>(norder);
172+
173+
std::complex<float>* v = new std::complex<float>[4];
174+
v[0] = 1.0;
175+
v[1] = 0.0;
176+
v[2] = 0.0;
177+
v[3] = 1.0; //[1 0; 0 1]
178+
float tmin = -1.1;
179+
float tmax = 1.1;
180+
bool converge;
181+
182+
auto fun_sigma_yf
183+
= [&](std::complex<float>* in, std::complex<float>* out, const int m = 1) { fun.sigma_y(in, out, m); };
184+
converge = p_fchetest->checkconverge(fun_sigma_yf, v, 2, 2, tmax, tmin, 0.2);
185+
EXPECT_TRUE(converge);
186+
converge = p_fchetest->checkconverge(fun_sigma_yf, v + 2, 2, 2, tmax, tmin, 0.2);
187+
EXPECT_TRUE(converge);
188+
EXPECT_NEAR(tmin, -1.1, 1e-6);
189+
EXPECT_NEAR(tmax, 1.1, 1e-6);
190+
191+
delete[] v;
192+
delete p_fchetest;
193+
}
194+
#endif
195+
196+
int main(int argc, char** argv)
197+
{
198+
#ifdef __MPI
199+
MPI_Init(&argc, &argv);
200+
#endif
201+
testing::InitGoogleTest(&argc, argv);
202+
int result = RUN_ALL_TESTS();
203+
#ifdef __MPI
204+
MPI_Finalize();
205+
#endif
206+
return result;
207+
}

source/module_basis/module_pw/module_fft/fft_cpu.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,18 +347,22 @@ void FFT_CPU<double>::fftxyfor(std::complex<double>* in, std::complex<double>* o
347347
int npy = this->nplane * this->ny;
348348
if (this->xprime)
349349
{
350+
350351
fftw_execute_dft(this->planxfor1, (fftw_complex*)in, (fftw_complex*)out);
352+
#pragma omp parallel for
351353
for (int i = 0; i < this->lixy + 1; ++i)
352354
{
353355
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
354356
}
357+
#pragma omp parallel for
355358
for (int i = rixy; i < this->nx; ++i)
356359
{
357360
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
358361
}
359362
}
360363
else
361364
{
365+
#pragma omp parallel for
362366
for (int i = 0; i < this->nx; ++i)
363367
{
364368
fftw_execute_dft(this->planyfor, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
@@ -374,10 +378,12 @@ void FFT_CPU<double>::fftxybac(std::complex<double>* in,std::complex<double>* ou
374378
int npy = this->nplane * this->ny;
375379
if (this->xprime)
376380
{
381+
#pragma omp parallel for
377382
for (int i = 0; i < this->lixy + 1; ++i)
378383
{
379384
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
380385
}
386+
#pragma omp parallel for
381387
for (int i = rixy; i < this->nx; ++i)
382388
{
383389
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
@@ -388,6 +394,7 @@ void FFT_CPU<double>::fftxybac(std::complex<double>* in,std::complex<double>* ou
388394
{
389395
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)out);
390396
fftw_execute_dft(this->planxbac2, (fftw_complex*)&in[rixy * nplane], (fftw_complex*)&out[rixy * nplane]);
397+
#pragma omp parallel for
391398
for (int i = 0; i < this->nx; ++i)
392399
{
393400
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&out[i * npy]);
@@ -414,13 +421,15 @@ void FFT_CPU<double>::fftxyr2c(double* in, std::complex<double>* out) const
414421
if (this->xprime)
415422
{
416423
fftw_execute_dft_r2c(this->planxr2c, in, (fftw_complex*)out);
424+
#pragma omp parallel for
417425
for (int i = 0; i < this->lixy + 1; ++i)
418426
{
419427
fftw_execute_dft(this->planyfor, (fftw_complex*)&out[i * npy], (fftw_complex*)&out[i * npy]);
420428
}
421429
}
422430
else
423431
{
432+
#pragma omp parallel for
424433
for (int i = 0; i < this->nx; ++i)
425434
{
426435
fftw_execute_dft_r2c(this->planyr2c, &in[i * npy], (fftw_complex*)&out[i * npy]);
@@ -435,6 +444,7 @@ void FFT_CPU<double>::fftxyc2r(std::complex<double> *in,double *out) const
435444
int npy = this->nplane * this->ny;
436445
if (this->xprime)
437446
{
447+
#pragma omp parallel for
438448
for (int i = 0; i < this->lixy + 1; ++i)
439449
{
440450
fftw_execute_dft(this->planybac, (fftw_complex*)&in[i * npy], (fftw_complex*)&in[i * npy]);
@@ -444,6 +454,7 @@ void FFT_CPU<double>::fftxyc2r(std::complex<double> *in,double *out) const
444454
else
445455
{
446456
fftw_execute_dft(this->planxbac1, (fftw_complex*)in, (fftw_complex*)in);
457+
#pragma omp parallel for
447458
for (int i = 0; i < this->nx; ++i)
448459
{
449460
fftw_execute_dft_c2r(this->planyc2r, (fftw_complex*)&in[i * npy], &out[i * npy]);

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::mo
1717
classname="PW_Basis";
1818
this->fft_bundle.setfft("cpu",this->precision);
1919
this->double_data_ = (this->precision == "double") || (this->precision == "mixing");
20-
this->float_data_ = (this->precision == "single") || (this->precision == "mixing");
20+
this->float_data_ = (this->precision == "single") || (this->precision == "mixing");
2121
}
2222

2323
PW_Basis:: ~PW_Basis()

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,11 @@ void PW_Basis_K::setuptransform()
203203
this->getstartgr();
204204
this->setupIndGk();
205205
this->fft_bundle.clear();
206+
std::string fft_device = this->device;
206207
#if defined(__DSP)
207-
this->fft_bundle.setfft("dsp", this->precision);
208-
#else
209-
this->fft_bundle.setfft(this->device, this->precision);
208+
fft_device = "dsp";
210209
#endif
210+
this->fft_bundle.setfft(fft_device, this->precision);
211211
if (this->xprime)
212212
{
213213
this->fft_bundle.initfft(this->nx,

source/module_basis/module_pw/pw_gatherscatter.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
9898
template <typename T>
9999
void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
100100
{
101-
//ModuleBase::timer::tick(this->classname, "gathers_scatterp");
102-
101+
// ModuleBase::timer::tick(this->classname, "gathers_scatterp");
103102
if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot,
104103
{
105104
#ifdef _OPENMP
@@ -183,7 +182,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
183182
}
184183
}
185184
#endif
186-
//ModuleBase::timer::tick(this->classname, "gathers_scatterp");
185+
// ModuleBase::timer::tick(this->classname, "gathers_scatterp");
187186
return;
188187
}
189188

source/module_basis/module_pw/pw_transform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
210210
#endif
211211
for (int i = 0; i < this->nst * this->nz; ++i)
212212
{
213-
fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<double>(0, 0);
213+
fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<FPTYPE>(0, 0);
214214
}
215215

216216
#ifdef _OPENMP

0 commit comments

Comments
 (0)