Skip to content

Commit 88c25d7

Browse files
committed
change teh cmake file
1 parent e5d9214 commit 88c25d7

File tree

9 files changed

+26
-35
lines changed

9 files changed

+26
-35
lines changed

CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,8 @@ target_link_libraries(${ABACUS_BIN_NAME} ${SCALAPACK_LIBRARY_DIR})
257257

258258
if (USE_DSP)
259259
add_compile_definitions(__DSP)
260-
target_link_libraries(${ABACUS_BIN_NAME} ${MTBLAS_FFT_LIBRARY_DIR})
261260
target_link_libraries(${ABACUS_BIN_NAME} ${OMPI_LIBRARY1})
262-
include_directories(${MTBLAS_FFT_LIBRARY_DIR}/include)
261+
include_directories(${MTBLAS_FFT_DIR}/libmtblas/include)
263262
include_directories(${MT_HOST_DIR}/include)
264263
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_device.a)
265264
target_link_libraries(${ABACUS_BIN_NAME} ${MT_HOST_DIR}/hthreads/lib/libhthread_host.a)

source/module_base/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ add_library(
6666

6767
target_link_libraries(base PUBLIC container)
6868
if (USE_DSP)
69-
target_link_libraries(base PUBLIC ${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtblas.a)
70-
target_link_libraries(base PUBLIC ${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtblasdev.a)
69+
target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblas.a)
70+
target_link_libraries(base PUBLIC ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtblasdev.a)
7171
endif()
7272
add_subdirectory(module_container)
7373

source/module_basis/module_pw/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ add_library(
4444

4545
if (USE_DSP)
4646
target_link_libraries(planewave PRIVATE
47-
${MTBLAS_FFT_LIBRARY_DIR}/lib/libmtfft.a)
47+
${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a)
48+
target_compile_definitions( planewave PUBLIC
49+
FFT_DAT_DIR="${MTBLAS_FFT_DIR}/datfile/mt_fft_blas.dat")
4850
endif()
4951
if(ENABLE_COVERAGE)
5052
add_coverage(planewave)

source/module_basis/module_pw/module_fft/fft_dsp.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ void FFT_DSP<double>::setupFFT()
2323
PLAN* ptr_plan_backward;
2424
INT num_thread=8;
2525
INT size;
26-
27-
hthread_dat_load(cluster_id, "/vol8/home/dptech_zyz1/develop/blasfft/mtfftblas/datfile/mt_fft_blas.dat");
26+
hthread_dat_load(cluster_id, FFT_DAT_DIR);
2827

2928
//compute the size of and malloc thread
3029
size = nx*ny*nz*2*sizeof(E);

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ PW_Basis::PW_Basis()
1515

1616
PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) {
1717
classname="PW_Basis";
18-
this->fft_bundle.setfft("cpu",this->precision);
18+
this->fft_bundle.setfft(this->device,this->precision);
1919
}
2020

2121
PW_Basis:: ~PW_Basis()

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ PW_Basis_K::~PW_Basis_K()
2222
delete[] igl2isz_k;
2323
delete[] igl2ig_k;
2424
delete[] gk2;
25-
#if defined(__DSP)
26-
delete[] ig2ixyz_k_cpu;
27-
#endif
2825
#if defined(__CUDA) || defined(__ROCM)
2926
if (this->device == "gpu") {
3027
if (this->precision == "single") {
@@ -148,7 +145,7 @@ void PW_Basis_K::setupIndGk()
148145

149146
//get igl2isz_k and igl2ig_k
150147
if(this->npwk_max <= 0) { return;}
151-
148+
152149
delete[] igl2isz_k; this->igl2isz_k = new int [this->nks * this->npwk_max];
153150
delete[] igl2ig_k; this->igl2ig_k = new int [this->nks * this->npwk_max];
154151
for (int ik = 0; ik < this->nks; ik++)
@@ -188,7 +185,11 @@ void PW_Basis_K::setuptransform()
188185
this->getstartgr();
189186
this->setupIndGk();
190187
this->fft_bundle.clear();
191-
this->fft_bundle.setfft("dsp",this->precision);
188+
#if defined(__DSP)
189+
this->fft_bundle.setfft("dsp",this->precision);
190+
#else
191+
this->fft_bundle.setfft("cpu",this->precision);
192+
#endif
192193
if(this->xprime){
193194
this->fft_bundle.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime);
194195
}else{
@@ -337,12 +338,12 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const
337338

338339
void PW_Basis_K::get_ig2ixyz_k()
339340
{
340-
if (this->device != "gpu")
341-
{
342-
//only GPU need to get ig2ixyz_k
343-
return;
344-
}
345-
int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks];
341+
// if (this->device != "gpu")
342+
// {
343+
// //only GPU need to get ig2ixyz_k
344+
// return;
345+
// }
346+
ig2ixyz_k_cpu.resize(this->npwk_max * this->nks);
346347
ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks);
347348
assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily.
348349
for(int ik = 0; ik < this->nks; ++ik)
@@ -359,10 +360,7 @@ void PW_Basis_K::get_ig2ixyz_k()
359360
}
360361
}
361362
resmem_int_op()(ig2ixyz_k, this->npwk_max * this->nks);
362-
syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
363-
#if not defined (__DSP)
364-
delete[] this->ig2ixyz_k_cpu;
365-
#endif
363+
syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu.data(), this->npwk_max * this->nks);
366364
}
367365

368366
std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class PW_Basis_K : public PW_Basis
8787
int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
8888
int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
8989
int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz
90-
int *ig2ixyz_k_cpu = nullptr; /// [npw] map ig to ixyz,which is used in dsp fft.
90+
std::vector<int> ig2ixyz_k_cpu; /// [npw] map ig to ixyz,which is used in dsp fft.
9191
double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
9292

9393
// liuyu add 2023-09-06

source/module_basis/module_pw/pw_transform_k.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,6 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/,
308308
const double factor) const
309309
{
310310
#if defined(__DSP)
311-
printf("beforce the real_to_recip\n");
312311
this->real2recip_dsp(in,out,ik,add,factor);
313312
#else
314313
this->real2recip(in, out, ik, add, factor);
@@ -334,7 +333,6 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
334333
const double factor) const
335334
{
336335
#if defined(__DSP)
337-
printf("beforce the recip_to_real\n");
338336
this->recip2real_dsp(in,out,ik,add,factor);
339337
#else
340338
this->recip2real(in, out, ik, add, factor);

source/module_basis/module_pw/pw_transform_k_dsp.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace ModulePW
3535
this->nxyz,
3636
add,
3737
factor,
38-
this->ig2ixyz_k_cpu + startig,
38+
this->ig2ixyz_k_cpu.data() + startig,
3939
auxr,
4040
out);
4141
}
@@ -49,23 +49,19 @@ namespace ModulePW
4949
assert(this->gamma_only == false);
5050
const base_device::DEVICE_CPU* ctx;
5151
const base_device::DEVICE_GPU* gpux;
52-
printf("beforce the recip2real_dsp\n");
5352
// memset the auxr of 0 in the auxr,here the len of the auxr is nxyz
5453
auto * auxr = this->fft_bundle.get_auxr_3d_data<double>();
5554
memset(auxr,0,this->nxyz*2*8);
5655

5756
const int startig = ik * this->npwk_max;
5857
const int npw_k = this->npwk[ik];
59-
printf("beforce the set_3d_fft_box_op\n");
6058
//copy the mapping form the type of stick to the 3dfft
6159
set_3d_fft_box_op<double,base_device::DEVICE_CPU>()
6260
(
63-
ctx,npw_k,this->ig2ixyz_k_cpu+startig,in,auxr
61+
ctx,npw_k,this->ig2ixyz_k_cpu.data()+startig,in,auxr
6462
);
65-
printf("beforce the fft3D_backward\n");
6663
// use 3d fft backward
6764
this->fft_bundle.fft3D_backward(gpux,auxr,auxr);
68-
printf("beforce the add\n");
6965
if(add)
7066
{
7167
const int one =1;
@@ -76,7 +72,6 @@ namespace ModulePW
7672
{
7773
memcpy(out,auxr,nrxx*2*8);
7874
}
79-
printf("after the add\n");
8075
}
8176
template <>
8277
void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
@@ -114,7 +109,7 @@ namespace ModulePW
114109
//copy the mapping form the type of stick to the 3dfft
115110
set_3d_fft_box_op<double,base_device::DEVICE_CPU>()
116111
(
117-
ctx,npw_k,this->ig2ixyz_k_cpu+startig,input,auxr
112+
ctx,npw_k,this->ig2ixyz_k_cpu.data()+startig,input,auxr
118113
);
119114

120115
// use 3d fft backward
@@ -135,7 +130,7 @@ namespace ModulePW
135130
this->nxyz,
136131
add,
137132
factor,
138-
this->ig2ixyz_k_cpu + startig,
133+
this->ig2ixyz_k_cpu.data() + startig,
139134
auxr,
140135
output);
141136
ModuleBase::timer::tick(this->classname,"convolution");

0 commit comments

Comments
 (0)