Skip to content

Commit 4a2894b

Browse files
committed
add information in map
1 parent c1dac21 commit 4a2894b

File tree

7 files changed

+118
-8
lines changed

7 files changed

+118
-8
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ endif()
256256
if (USE_DSP)
257257
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY})
258258
add_compile_definitions(__DSP)
259+
target_include_directories(${ABACUS_BIN_NAME} ${DIR_MTFFT_INCLUDES})
260+
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTFFT_LIBRARY})
261+
target_include_directories(${ABACUS_BIN_NAME} ${DIR_HTRERAD_INLCUDES})
259262
endif()
260263

261264
find_package(Threads REQUIRED)

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#if defined(__ROCM)
1010
#include "fft_rocm.h"
1111
#endif
12-
12+
#if defined(__DSP)
13+
#include "fft_dsp.h"
14+
#endif
1315
template<typename FFT_BASE, typename... Args>
1416
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)
1517
{
@@ -67,7 +69,7 @@ void FFT_Bundle::initfft(int nx_in,
6769
{
6870
#if defined(__DSP)
6971
if (float_flag==true)
70-
ModuleBase::WARNING_QUT("device","now dsp is not support for the float type");
72+
ModuleBase::WARNING_QUT("device","now dsp fft is not support for the float type");
7173
fft_double=make_unique<FFT_DSP<double>>();
7274
fft_double->initfft(nx_in,ny_in,nz_in);
7375
#endif

source/module_basis/module_pw/module_fft/fft_dsp.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
// #define protected public
2-
// #define private public
31
#include "fft_dsp.h"
42
#include <string.h>
53
#include <iostream>
64
#include <vector>
7-
// #undef private
8-
// #undef protected
95
namespace ModulePW
106
{
117
template<>

source/module_basis/module_pw/pw_basis_k.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ PW_Basis_K::~PW_Basis_K()
2222
delete[] igl2isz_k;
2323
delete[] igl2ig_k;
2424
delete[] gk2;
25+
#if defined(__DSP)
26+
delete[] ig2ixyz_k_cpu;
27+
#endif
2528
#if defined(__CUDA) || defined(__ROCM)
2629
if (this->device == "gpu") {
2730
if (this->precision == "single") {
@@ -357,7 +360,9 @@ void PW_Basis_K::get_ig2ixyz_k()
357360
}
358361
resmem_int_op()(ig2ixyz_k, this->npwk_max * this->nks);
359362
syncmem_int_h2d_op()(this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
360-
delete[] ig2ixyz_k_cpu;
363+
#if not defined (__DSP)
364+
delete[] this->ig2ixyz_k_cpu;
365+
#endif
361366
}
362367

363368
std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const

source/module_basis/module_pw/pw_basis_k.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class PW_Basis_K : public PW_Basis
8787
int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
8888
int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
8989
int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz
90-
90+
int *ig2ixyz_k_cpu = nullptr; /// [npw] map ig to ixyz,which is used in dsp fft.
9191
double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
9292

9393
// liuyu add 2023-09-06
@@ -135,6 +135,20 @@ class PW_Basis_K : public PW_Basis
135135
const int ik,
136136
const bool add = false,
137137
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
138+
#if defined(__DSP)
139+
template <typename FPTYPE>
140+
void real2recip_3d(const std::complex<FPTYPE>* in,
141+
std::complex<FPTYPE>* out,
142+
const int ik,
143+
const bool add = false,
144+
const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns)
145+
template <typename FPTYPE>
146+
void recip2real_3d(const std::complex<FPTYPE>* in,
147+
std::complex<FPTYPE>* out,
148+
const int ik,
149+
const bool add = false,
150+
const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny)
151+
#endif
138152

139153
template <typename FPTYPE, typename Device>
140154
void real_to_recip(const Device* ctx,

source/module_basis/module_pw/pw_transform_k.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/,
307307
const bool add,
308308
const double factor) const
309309
{
310+
#if defined(__DSP)
311+
this->real2recip_3d(in,out,ik,add,factor);
312+
#else
310313
this->real2recip(in, out, ik, add, factor);
314+
#endif
311315
}
312316

313317
template <>
@@ -318,7 +322,11 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
318322
const bool add,
319323
const float factor) const
320324
{
325+
#if defined(__DSP)
326+
this->recip2real_3d(in,out,add,factor);
327+
#else
321328
this->recip2real(in, out, ik, add, factor);
329+
#endif
322330
}
323331
template <>
324332
void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/,
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "module_base/timer.h"
2+
#include "module_basis/module_pw/kernels/pw_op.h"
3+
#include "pw_basis_k.h"
4+
5+
#include <cassert>
6+
#include <complex>
7+
#include <string>
8+
namespace ModulePW
9+
{
10+
template <typename FPTYPE>
11+
void PW_Basis_K::real2recip_3d(const std::complex<FPTYPE>* in,
12+
std::complex<FPTYPE>* out,
13+
const int ik,
14+
const bool add,
15+
const FPTYPE factor) const
16+
{
17+
ModuleBase::timer::tick(this->classname,"real2recip_3d");
18+
const base_device::DEVICE_CPU* ctx;
19+
const base_device::DEVICE_GPU* gpux;
20+
assert(this->gamma_only == false);
21+
auto* auxr = this->fft_bundle.get_auxr_3d_data<double>();
22+
23+
const int startig = ik * this->npwk_max;
24+
const int npw_k = this->npwk[ik];
25+
memcpy(auxr,in,this->nrxx*2*8);
26+
this->fft_bundle.fft3D_forward(gpux,
27+
auxr,
28+
auxr);
29+
set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(ctx,
30+
npw_k,
31+
this->nxyz,
32+
add,
33+
factor,
34+
this->ig2ixyz_k_cpu + startig,
35+
this->fft_bundle.get_auxr_3d_data<double>(),
36+
out);
37+
ModuleBase::timer::tick(this->classname,"real2recip_3d");
38+
}
39+
40+
template <typename FPTYPE>
41+
void PW_Basis_K::recip2real_3d(const std::complex<FPTYPE>* in,
42+
std::complex<FPTYPE>* out,
43+
const int ik,
44+
const bool add,
45+
const FPTYPE factor) const
46+
{
47+
ModuleBase::timer::tick(this->classname,"recip2real_3d");
48+
49+
assert(this->gamma_only == false);
50+
const base_device::DEVICE_CPU* ctx;
51+
const base_device::DEVICE_GPU* gpux;
52+
auto* auxr = this->fft_bundle.get_auxr_3d_data<double>();
53+
memset(auxr,0,this->nrxx*2*8);
54+
const int startig = ik * this->npwk_max;
55+
const int npw_k = this->npwk[ik];
56+
57+
set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx,
58+
npw_k,
59+
this->ig2ixyz_k_cpu + startig,
60+
in,
61+
auxr);
62+
this->fft_bundle.fft3D_backward(gpux,auxr,auxr);
63+
set_recip_to_real_output_op<double, base_device::DEVICE_CPU>()(ctx,
64+
this->nrxx,
65+
add,
66+
factor,
67+
auxr,
68+
out);
69+
ModuleBase::timer::tick(this->classname,"recip2real_3d");
70+
}
71+
72+
template void PW_Basis_K::real2recip_3d<double>(const std::complex<double>* in,
73+
std::complex<double>* out,
74+
const int ik,
75+
const bool add,
76+
const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns)
77+
template void PW_Basis_K::recip2real_3d<double>(const std::complex<double>* in,
78+
std::complex<double>* out,
79+
const int ik,
80+
const bool add,
81+
const double factor) const; // in:(nz, ns) ; out(nplane,nx*ny)
82+
}

0 commit comments

Comments
 (0)