Skip to content

Commit 541dde2

Browse files
committed
update the psi.h
1 parent 6e6855f commit 541dde2

File tree

14 files changed

+135
-126
lines changed

14 files changed

+135
-126
lines changed
Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11

2-
#include "fftw3.h"
3-
#if defined(__FFTW3_MPI) && defined(__MPI)
4-
#include <fftw3-mpi.h>
5-
//#include "fftw3-mpi_mkl.h"
6-
#endif
72

83
#if defined(__CUDA) || defined(__UT_USE_CUDA)
9-
#include "cufft.h"
4+
// #include "cufft.h"
105
#include "cuda_runtime.h"
116
#endif
127

@@ -17,6 +12,3 @@
1712

1813

1914
#include "module_psi/psi.h"
20-
21-
22-

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
#include <complex>
2-
#include <string>
3-
#include "fftw3.h"
41
#ifndef FFT_BASE_H
52
#define FFT_BASE_H
3+
4+
#include <complex>
5+
#include <string>
66
namespace ModulePW
77
{
88
template <typename FPTYPE>

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
#include "fft_base.h"
2-
#include <memory>
3-
// #include "module_psi/psi.h"
41
#ifndef FFT_TEMP_H
52
#define FFT_TEMP_H
3+
4+
#include "fft_base.h"
5+
#include <memory>
6+
#include "fft_cpu.h"
7+
#ifdef __CUDA
8+
#include "fft_cuda.h"
9+
#endif
10+
#ifdef __ROCM
11+
#include "fft_rocm.h"
12+
#endif
613
namespace ModulePW
714
{
815
class FFT_Bundle

source/module_basis/module_pw/module_fft/fft_cuda.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "fft_cuda.h"
22
#include "module_base/module_device/memory_op.h"
3-
#include "module_hamilt_pw/hamilt_pwdft/global.h"
3+
44
namespace ModulePW
55
{
66
template <typename FPTYPE>

source/module_basis/module_pw/module_fft/fft_cuda.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
#include "fft_base.h"
2-
#include "cufft.h"
3-
#include "cuda_runtime.h"
4-
51
#ifndef FFT_CUDA_H
62
#define FFT_CUDA_H
3+
4+
#include "fft_base.h"
5+
#include "kernel/fft_cuda_func.h"
6+
77
namespace ModulePW
88
{
99
template <typename FPTYPE>

source/module_basis/module_pw/module_fft/fft_rcom.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
#include "fft_base.h"
2-
#include <hipfft/hipfft.h>
3-
#include <hip/hip_runtime.h>
41
#ifndef FFT_ROCM_H
52
#define FFT_ROCM_H
3+
4+
#include "fft_base.h"
5+
#include "kernel/fft_rcom_func.h"
66
namespace ModulePW
77
{
88
template <typename FPTYPE>
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#ifndef FFT_CUDA_FUNC_H
2+
#define FFT_CUDA_FUNC_H
3+
#include "cufft.h"
4+
#include "cuda_runtime.h"
5+
6+
static const char* _cufftGetErrorString(cufftResult_t error)
7+
{
8+
switch (error)
9+
{
10+
case CUFFT_SUCCESS:
11+
return "CUFFT_SUCCESS";
12+
case CUFFT_INVALID_PLAN:
13+
return "CUFFT_INVALID_PLAN";
14+
case CUFFT_ALLOC_FAILED:
15+
return "CUFFT_ALLOC_FAILED";
16+
case CUFFT_INVALID_TYPE:
17+
return "CUFFT_INVALID_TYPE";
18+
case CUFFT_INVALID_VALUE:
19+
return "CUFFT_INVALID_VALUE";
20+
case CUFFT_INTERNAL_ERROR:
21+
return "CUFFT_INTERNAL_ERROR";
22+
case CUFFT_EXEC_FAILED:
23+
return "CUFFT_EXEC_FAILED";
24+
case CUFFT_SETUP_FAILED:
25+
return "CUFFT_SETUP_FAILED";
26+
case CUFFT_INVALID_SIZE:
27+
return "CUFFT_INVALID_SIZE";
28+
case CUFFT_UNALIGNED_DATA:
29+
return "CUFFT_UNALIGNED_DATA";
30+
case CUFFT_INCOMPLETE_PARAMETER_LIST:
31+
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
32+
case CUFFT_INVALID_DEVICE:
33+
return "CUFFT_INVALID_DEVICE";
34+
case CUFFT_PARSE_ERROR:
35+
return "CUFFT_PARSE_ERROR";
36+
case CUFFT_NO_WORKSPACE:
37+
return "CUFFT_NO_WORKSPACE";
38+
case CUFFT_NOT_IMPLEMENTED:
39+
return "CUFFT_NOT_IMPLEMENTED";
40+
case CUFFT_LICENSE_ERROR:
41+
return "CUFFT_LICENSE_ERROR";
42+
case CUFFT_NOT_SUPPORTED:
43+
return "CUFFT_NOT_SUPPORTED";
44+
}
45+
return "<unknown>";
46+
}
47+
48+
#define CHECK_CUFFT(func) \
49+
{ \
50+
cufftResult_t status = (func); \
51+
if (status != CUFFT_SUCCESS) \
52+
{ \
53+
printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
54+
_cufftGetErrorString(status), status); \
55+
} \
56+
}
57+
#endif // FFT_CUDA_FUNC_H
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#ifndef FFT_ROCM_FUNC_H
2+
#define FFT_ROCM_FUNC_H
3+
#include <hipfft/hipfft.h>
4+
#include <hip/hip_runtime.h>
5+
static const char* _hipfftGetErrorString(hipfftResult_t error)
6+
{
7+
switch (error)
8+
{
9+
case HIPFFT_SUCCESS:
10+
return "HIPFFT_SUCCESS";
11+
case HIPFFT_INVALID_PLAN:
12+
return "HIPFFT_INVALID_PLAN";
13+
case HIPFFT_ALLOC_FAILED:
14+
return "HIPFFT_ALLOC_FAILED";
15+
case HIPFFT_INVALID_TYPE:
16+
return "HIPFFT_INVALID_TYPE";
17+
case HIPFFT_INVALID_VALUE:
18+
return "HIPFFT_INVALID_VALUE";
19+
case HIPFFT_INTERNAL_ERROR:
20+
return "HIPFFT_INTERNAL_ERROR";
21+
case HIPFFT_EXEC_FAILED:
22+
return "HIPFFT_EXEC_FAILED";
23+
case HIPFFT_SETUP_FAILED:
24+
return "HIPFFT_SETUP_FAILED";
25+
case HIPFFT_INVALID_SIZE:
26+
return "HIPFFT_INVALID_SIZE";
27+
case HIPFFT_UNALIGNED_DATA:
28+
return "HIPFFT_UNALIGNED_DATA";
29+
case HIPFFT_INCOMPLETE_PARAMETER_LIST:
30+
return "HIPFFT_INCOMPLETE_PARAMETER_LIST";
31+
case HIPFFT_INVALID_DEVICE:
32+
return "HIPFFT_INVALID_DEVICE";
33+
case HIPFFT_PARSE_ERROR:
34+
return "HIPFFT_PARSE_ERROR";
35+
case HIPFFT_NO_WORKSPACE:
36+
return "HIPFFT_NO_WORKSPACE";
37+
case HIPFFT_NOT_IMPLEMENTED:
38+
return "HIPFFT_NOT_IMPLEMENTED";
39+
case HIPFFT_NOT_SUPPORTED:
40+
return "HIPFFT_NOT_SUPPORTED";
41+
}
42+
return "<unknown>";
43+
}
44+
#define CHECK_CUFFT(func) \
45+
{ \
46+
hipfftResult_t status = (func); \
47+
if (status != HIPFFT_SUCCESS) \
48+
{ \
49+
printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
50+
_hipfftGetErrorString(status), status); \
51+
} \
52+
}
53+
#endif // FFT_ROCM_FUNC_H

source/module_basis/module_pw/pw_basis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#ifndef PWBASIS_H
22
#define PWBASIS_H
33

4+
#include "module_base/module_device/memory_op.h"
45
#include "module_base/matrix.h"
56
#include "module_base/matrix3.h"
67
#include "module_base/vector3.h"
78
#include <complex>
8-
#include "fft.h"
99
#include "module_fft/fft_bundle.h"
1010
#include <cstring>
1111
#ifdef __MPI

source/module_basis/module_pw/pw_transform.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "fft.h"
21
#include "module_fft/fft_bundle.h"
32
#include <complex>
43
#include "pw_basis.h"

0 commit comments

Comments
 (0)