Skip to content

Commit 897df26

Browse files
committed
add RALL
1 parent 3e3e4c7 commit 897df26

File tree

5 files changed

+53
-31
lines changed

5 files changed

+53
-31
lines changed

CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,6 @@ if(ENABLE_MPI)
269269
add_compile_definitions(__MPI)
270270
list(APPEND math_libs MPI::MPI_CXX)
271271
endif()
272-
target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR})
273272

274273

275274
if (USE_DSP)
@@ -459,10 +458,8 @@ else()
459458
find_package(Lapack REQUIRED)
460459
include_directories(${FFTW3_INCLUDE_DIRS})
461460
list(APPEND math_libs FFTW3::FFTW3 LAPACK::LAPACK BLAS::BLAS)
462-
if (ENBALE_LCAO)
463461
find_package(ScaLAPACK REQUIRED)
464462
list(APPEND math_libs ScaLAPACK::ScaLAPACK)
465-
endif()
466463
if(USE_OPENMP)
467464
list(APPEND math_libs FFTW3::FFTW3_OMP)
468465
endif()

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,18 @@ class FFT_Bundle
203203
std::string device = "cpu";
204204
std::string precision = "double";
205205
};
206+
// Use RAII (Resource Acquisition Is Initialization) to
207+
// control the resources used by hthread when setting the DSP
208+
struct FFT_RALL
209+
{
210+
const FFT_Bundle& fft_;
211+
FFT_RALL(const FFT_Bundle& fft) : fft_(fft)
212+
{fft_.resource_handler(1);}
213+
~FFT_RALL()
214+
{
215+
fft_.resource_handler(0);
216+
}
217+
};
218+
206219
} // namespace ModulePW
207220
#endif // FFT_H

source/module_basis/module_pw/module_fft/fft_dsp.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,7 @@ void FFT_DSP<double>::setupFFT()
6363
template <>
6464
void FFT_DSP<double>::resource_handler(const int flag) const
6565
{
66-
switch (ResourceState)
67-
{
68-
case constant expression:
69-
/* code */
70-
break;
71-
72-
default:
73-
break;
74-
}
66+
if (flag == 0)
7567
{
7668
hthread_barrier_destroy(b_id);
7769
hthread_group_destroy(thread_id_for);
@@ -84,6 +76,8 @@ void FFT_DSP<double>::resource_handler(const int flag) const
8476
b_id = hthread_barrier_create(cluster_id);
8577
args_for[0] = b_id;
8678
args_back[0] = b_id;
79+
}else{
80+
ModuleBase::WARNING_QUIT("FFT_DSP", "Error use of fft resource handle");
8781
}
8882
}
8983
template <>

source/module_basis/module_pw/module_fft/fft_dsp.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,7 @@
1212

1313
namespace ModulePW
1414
{
15-
enum class ResourceState
16-
{
17-
Destroyed = 0,
18-
Created = 1,
19-
INVALID
20-
}
15+
2116
template <typename FPTYPE>
2217
class FFT_DSP : public FFT_BASE<FPTYPE>
2318
{
@@ -30,7 +25,11 @@ class FFT_DSP : public FFT_BASE<FPTYPE>
3025
void clear() override;
3126

3227
void cleanFFT() override;
33-
28+
/**
29+
* @brief Control the allocation or deallocation of hthread
30+
* resource
31+
* @param flag 0: deallocate, 1: allocate
32+
*/
3433
void resource_handler(const int flag) const override;
3534
/**
3635
* @brief Initialize the fft parameters

source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,38 @@ void Veff<OperatorPW<T, Device>>::act(
5757
const int current_spin = this->isk[this->ik];
5858

5959
#ifdef __DSP
60-
// wfcpw->fft_bundle.resource_handler(1);
60+
ModulePW::FFT_RALL guard(wfcpw->fft_bundle);
61+
for (int ib = 0; ib<nbands ; ib += npol)
62+
{
63+
if (npol == 1)
64+
{
65+
wfcpw->convolution(this->ctx,
66+
this->ik,
67+
this->veff_col,
68+
tmpsi_in,
69+
this->veff+ current_spin* this->veff_col,
70+
tmhpsi,
71+
true);
72+
}else{
73+
// Should be replaced in the Convolution way.
74+
wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
75+
wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik);
76+
if(this->veff_col != 0)
77+
{
78+
/// denghui added at 20221109
79+
const Real* current_veff[4];
80+
for(int is = 0; is < 4; is++)
81+
{
82+
current_veff[is] = this->veff + is * this->veff_col ; // for CPU device
83+
}
84+
veff_op()(this->ctx, this->veff_col, this->porter, this->porter1, current_veff);
85+
}
86+
// FFT back to G space.
87+
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
88+
wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
89+
}
90+
}
6191
#endif
62-
// std::cout<<"the Device is "<<Device;
6392
for (int ib = 0; ib < nbands; ib += npol)
6493
{
6594
if (npol == 1)
@@ -73,13 +102,6 @@ void Veff<OperatorPW<T, Device>>::act(
73102
veff_op()(this->ctx, this->veff_col, this->porter, this->veff + current_spin * this->veff_col);
74103
}
75104
wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
76-
// wfcpw->convolution(this->ctx,
77-
// this->ik,
78-
// this->veff_col,
79-
// tmpsi_in,
80-
// this->veff+ current_spin* this->veff_col,
81-
// tmhpsi,
82-
// true);
83105
}
84106
else
85107
{
@@ -103,9 +125,6 @@ void Veff<OperatorPW<T, Device>>::act(
103125
tmhpsi += max_npw * npol;
104126
tmpsi_in += max_npw * npol;
105127
}
106-
#ifdef __DSP
107-
// wfcpw->fft_bundle.resource_handler(0);
108-
#endif
109128
ModuleBase::timer::tick("Operator", "veff_pw");
110129
}
111130

0 commit comments

Comments
 (0)