Skip to content

Commit b1fe0dd

Browse files
committed
Merge branch 'develop' of https://github.com/mohanchen/abacus-mc into develop
2 parents 3c3289f + 6c559be commit b1fe0dd

File tree

12 files changed

+204
-80
lines changed

12 files changed

+204
-80
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ if(ENABLE_MPI)
269269
list(APPEND math_libs MPI::MPI_CXX)
270270
endif()
271271

272+
272273
if (USE_DSP)
273274
add_compile_definitions(__DSP)
274275
target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1})

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class FFT_BASE
1212
virtual ~FFT_BASE() {};
1313

1414
/**
15-
* @brief Initialize the fft parameters As virtual function.
15+
* @brief Initialize the fft parameters as virtual function.
1616
*
1717
* The function is used to initialize the fft parameters.
1818
*/
@@ -30,32 +30,40 @@ class FFT_BASE
3030
virtual __attribute__((weak)) void initfft(int nx_in, int ny_in, int nz_in);
3131

3232
/**
33-
* @brief Setup the fft Plan and data As pure virtual function.
33+
* @brief Setup the fft plan and data as pure virtual function.
3434
*
3535
* The function is set as pure virtual function.In order to
3636
* override the function in the derived class.In the derived
37-
* class, the function is used to setup the fft Plan and data.
37+
* class, the function is used to setup the fft plan and data.
3838
*/
3939
virtual void setupFFT() = 0;
4040

4141
/**
42-
* @brief Clean the fft Plan As pure virtual function.
42+
* @brief Clean the fft plan as pure virtual function.
4343
*
4444
* The function is set as pure virtual function.In order to
4545
* override the function in the derived class.In the derived
46-
* class, the function is used to clean the fft Plan.
46+
* class, the function is used to clean the fft plan.
4747
*/
4848
virtual void cleanFFT() = 0;
4949

5050
/**
51-
* @brief Clear the fft data As pure virtual function.
51+
* @brief Clear the fft data as pure virtual function.
5252
*
5353
* The function is set as pure virtual function.In order to
5454
* override the function in the derived class.In the derived
5555
* class, the function is used to clear the fft data.
5656
*/
5757
virtual void clear() = 0;
58-
58+
/**
59+
* @brief Allocate and destory the resoure in FFT running time,
60+
* Now it only used in the DSP mode.
61+
*
62+
* The function is set as pure virtual function.In order to
63+
* override the function in the derived class.In the derived
64+
* class, the function is used to allocate and destory the
65+
* resoure in FFT running time.
66+
*/
5967
virtual void resource_handler(const int flag) const {};
6068
/**
6169
* @brief Get the real space data in cpu-like fft

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,22 @@ void FFT_Bundle::initfft(int nx_in,
5050
if (this->precision == "single" || this->precision == "mixing")
5151
{
5252
float_flag = true;
53+
if (this->precision == "mixing")
54+
{
55+
double_flag = true;
56+
}
5357
#if not defined(__ENABLE_FLOAT_FFTW)
5458
if (this->device == "cpu")
5559
{
5660
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please enable float fftw in the cmake to use float fft");
5761
}
5862
#endif
5963
}
60-
if (this->precision == "double" || this->precision == "mixing")
64+
else if (this->precision == "double")
6165
{
6266
double_flag = true;
67+
}else{
68+
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please set the precision to single or double or mixing");
6369
}
6470
#if defined(__DSP)
6571
if (device == "dsp")
@@ -70,24 +76,23 @@ void FFT_Bundle::initfft(int nx_in,
7076
}
7177
fft_double = make_unique<FFT_DSP<double>>();
7278
fft_double->initfft(nx_in, ny_in, nz_in);
73-
}
79+
}else
7480
#endif
7581
if (device == "cpu")
7682
{
77-
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
78-
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
7983
if (float_flag)
8084
{
85+
fft_float = make_unique<FFT_CPU<float>>(this->fft_mode);
8186
fft_float
8287
->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in);
8388
}
8489
if (double_flag)
8590
{
91+
fft_double = make_unique<FFT_CPU<double>>(this->fft_mode);
8692
fft_double
8793
->initfft(nx_in, ny_in, nz_in, lixy_in, rixy_in, ns_in, nplane_in, nproc_in, gamma_only_in, xprime_in);
8894
}
89-
}
90-
if (device == "gpu")
95+
}else if (device == "gpu")
9196
{
9297
#if defined(__ROCM)
9398
fft_float = make_unique<FFT_ROCM<float>>();
@@ -100,6 +105,8 @@ void FFT_Bundle::initfft(int nx_in,
100105
fft_double = make_unique<FFT_CUDA<double>>();
101106
fft_double->initfft(nx_in, ny_in, nz_in);
102107
#endif
108+
}else{
109+
ModuleBase::WARNING_QUIT("FFT_Bundle", "Please set the device to cpu or gpu or dsp");
103110
}
104111
}
105112

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,18 @@ class FFT_Bundle
203203
std::string device = "cpu";
204204
std::string precision = "double";
205205
};
206+
// Use RAII (Resource Acquisition Is Initialization) to
207+
// control the resources used by hthread when setting the DSP
208+
struct FFT_Guard
209+
{
210+
const FFT_Bundle& fft_;
211+
FFT_Guard(const FFT_Bundle& fft) : fft_(fft)
212+
{fft_.resource_handler(1);}
213+
~FFT_Guard()
214+
{
215+
fft_.resource_handler(0);
216+
}
217+
};
218+
206219
} // namespace ModulePW
207220
#endif // FFT_H

source/module_basis/module_pw/module_fft/fft_dsp.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void FFT_DSP<double>::setupFFT()
6363
template <>
6464
void FFT_DSP<double>::resource_handler(const int flag) const
6565
{
66-
if (flag==0)
66+
if (flag == 0)
6767
{
6868
hthread_barrier_destroy(b_id);
6969
hthread_group_destroy(thread_id_for);
@@ -76,6 +76,8 @@ void FFT_DSP<double>::resource_handler(const int flag) const
7676
b_id = hthread_barrier_create(cluster_id);
7777
args_for[0] = b_id;
7878
args_back[0] = b_id;
79+
}else{
80+
ModuleBase::WARNING_QUIT("FFT_DSP", "Error use of fft resource handle");
7981
}
8082
}
8183
template <>

source/module_basis/module_pw/module_fft/fft_dsp.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
namespace ModulePW
1414
{
15+
1516
template <typename FPTYPE>
1617
class FFT_DSP : public FFT_BASE<FPTYPE>
1718
{
@@ -24,7 +25,12 @@ class FFT_DSP : public FFT_BASE<FPTYPE>
2425
void clear() override;
2526

2627
void cleanFFT() override;
27-
28+
/**
29+
* @brief Control the allocation or deallocation of hthread
30+
* resource
31+
* @param flag 0: deallocate, 1: allocate
32+
*/
33+
void resource_handler(const int flag) const override;
2834
/**
2935
* @brief Initialize the fft parameters
3036
* @param nx_in number of grid points in x direction

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class PW_Basis_K : public PW_Basis
187187
const typename GetTypeReal<TK>::type factor = 1.0) const
188188
{
189189
#if defined(__DSP)
190-
this->recip2real_dsp(in, out, ik, add, factor);
190+
this->real2recip_dsp(in, out, ik, add, factor);
191191
#else
192192
this->real2recip(in,out,ik,add,factor);
193193
#endif

source/module_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,30 @@
88
#if defined (__DSP)
99
namespace ModulePW
1010
{
11-
template <typename FPTYPE>
12-
void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
13-
std::complex<FPTYPE>* out,
11+
template <>
12+
void PW_Basis_K::real2recip_dsp(const std::complex<float>* in,
13+
std::complex<float>* out,
1414
const int ik,
1515
const bool add,
16-
const FPTYPE factor) const
16+
const float factor) const
17+
{
18+
19+
}
20+
template <>
21+
void PW_Basis_K::recip2real_dsp(const std::complex<float>* in,
22+
std::complex<float>* out,
23+
const int ik,
24+
const bool add,
25+
const float factor) const
26+
{
27+
28+
}
29+
template <>
30+
void PW_Basis_K::real2recip_dsp(const std::complex<double>* in,
31+
std::complex<double>* out,
32+
const int ik,
33+
const bool add,
34+
const double factor) const
1735
{
1836
const base_device::DEVICE_CPU* ctx;
1937
const base_device::DEVICE_GPU* gpux;
@@ -31,20 +49,20 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
3149
auxr);
3250
this->fft_bundle.resource_handler(0);
3351
// copy the result from the auxr to the out ,while consider the add
34-
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(npw_k,
52+
set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k,
3553
this->nxyz,
3654
add,
3755
factor,
3856
this->ig2ixyz_k_cpu.data() + startig,
3957
auxr,
4058
out);
4159
}
42-
template <typename FPTYPE>
43-
void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
44-
std::complex<FPTYPE>* out,
60+
template <>
61+
void PW_Basis_K::recip2real_dsp(const std::complex<double>* in,
62+
std::complex<double>* out,
4563
const int ik,
4664
const bool add,
47-
const FPTYPE factor) const
65+
const double factor) const
4866
{
4967
assert(this->gamma_only == false);
5068
const base_device::DEVICE_CPU* ctx;
@@ -128,16 +146,16 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
128146
ModuleBase::timer::tick(this->classname, "convolution");
129147
}
130148

131-
// template void PW_Basis_K::real2recip_dsp<float>(const std::complex<float>* in,
132-
// std::complex<float>* out,
133-
// const int ik,
134-
// const bool add,
135-
// const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns)
136-
// template void PW_Basis_K::recip2real_dsp<float>(const std::complex<float>* in,
137-
// std::complex<float>* out,
138-
// const int ik,
139-
// const bool add,
140-
// const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny)
149+
template void PW_Basis_K::real2recip_dsp<float>(const std::complex<float>* in,
150+
std::complex<float>* out,
151+
const int ik,
152+
const bool add,
153+
const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns)
154+
template void PW_Basis_K::recip2real_dsp<float>(const std::complex<float>* in,
155+
std::complex<float>* out,
156+
const int ik,
157+
const bool add,
158+
const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny)
141159

142160
template void PW_Basis_K::real2recip_dsp<double>(const std::complex<double>* in,
143161
std::complex<double>* out,

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,18 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
207207
else if (this->add_hexx_type == Add_Hexx_Type::R)
208208
{
209209
// read in Hexx(R)
210-
const std::string restart_HR_path = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
211-
bool all_exist = true;
210+
const std::string restart_HR_path = GlobalC::restart.folder + "HexxR" + std::to_string(GlobalV::MY_RANK);
211+
int all_exist = 1;
212212
for (int is = 0; is < PARAM.inp.nspin; ++is)
213213
{
214214
std::ifstream ifs(restart_HR_path + "_" + std::to_string(is) + ".csr");
215-
if (!ifs) { all_exist = false; break; }
215+
if (!ifs) { all_exist = 0; break; }
216216
}
217+
// Add MPI communication to synchronize all_exist across processes
218+
#ifdef __MPI
219+
// don't read in any files if one of the processes doesn't have it
220+
MPI_Allreduce(MPI_IN_PLACE, &all_exist, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
221+
#endif
217222
if (all_exist)
218223
{
219224
// Read HexxR in CSR format
@@ -228,11 +233,24 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
228233
{
229234
// Read HexxR in binary format (old version)
230235
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(GlobalV::MY_RANK);
231-
if (GlobalC::exx_info.info_ri.real_number) {
232-
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxd);
236+
std::ifstream ifs(restart_HR_path_cereal, std::ios::binary);
237+
int all_exist_cereal = ifs ? 1 : 0;
238+
#ifdef __MPI
239+
MPI_Allreduce(MPI_IN_PLACE, &all_exist_cereal, 1, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
240+
#endif
241+
if (!all_exist_cereal)
242+
{
243+
//no HexxR file in CSR or binary format
244+
this->restart = false;
233245
}
234-
else {
235-
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxc);
246+
else
247+
{
248+
if (GlobalC::exx_info.info_ri.real_number) {
249+
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxd);
250+
}
251+
else {
252+
ModuleIO::read_Hexxs_cereal(restart_HR_path_cereal, *Hexxc);
253+
}
236254
}
237255
}
238256
}

0 commit comments

Comments
 (0)