Skip to content

Commit 7ca986b

Browse files
committed
revert check_func
1 parent 128283a commit 7ca986b

File tree

10 files changed

+119
-130
lines changed

10 files changed

+119
-130
lines changed

source/module_basis/module_pw/module_fft/fft_base.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#ifndef FFT_BASE_H
22
#define FFT_BASE_H
3-
43
#include <complex>
54
#include <string>
6-
#include "module_base/module_device/memory_op.h"
75
namespace ModulePW
86
{
97
template <typename FPTYPE>

source/module_basis/module_pw/module_fft/fft_bundle.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
#include <cassert>
22
#include "fft_bundle.h"
3-
#if defined(__CUDA)
4-
#include "fft_cuda.h"
5-
#endif
6-
#if defined(__ROCM)
7-
#include "fft_rcom.h"
8-
#endif
3+
94

105
template<typename FFT_BASE, typename... Args>
116
std::unique_ptr<FFT_BASE> make_unique(Args &&... args)

source/module_basis/module_pw/module_fft/fft_bundle.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
#ifndef FFT_TEMP_H
22
#define FFT_TEMP_H
33

4+
#include <memory>
45
#include "fft_base.h"
56
#include "fft_cpu.h"
67
#include "module_base/module_device/device.h"
78
#include "module_base/module_device/memory_op.h"
8-
#include <memory>
9+
#if defined(__CUDA)
10+
#include "fft_cuda.h"
11+
#endif
12+
#if defined(__ROCM)
13+
#include "fft_rcom.h"
14+
#endif
915

1016
namespace ModulePW
1117
{

source/module_basis/module_pw/module_fft/fft_cuda.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "fft_cuda.h"
22
#include "module_base/module_device/memory_op.h"
3+
#include "module_hamilt_pw/hamilt_pwdft/global.h"
34

45
namespace ModulePW
56
{
@@ -105,4 +106,9 @@ template <> std::complex<float>*
105106
FFT_CUDA<float>::get_auxr_3d_data() const {return this->c_auxr_3d;}
106107
template <> std::complex<double>*
107108
FFT_CUDA<double>::get_auxr_3d_data() const {return this->z_auxr_3d;}
109+
110+
template FFT_CUDA<float>::FFT_CUDA();
111+
template FFT_CUDA<float>::~FFT_CUDA();
112+
template FFT_CUDA<double>::FFT_CUDA();
113+
template FFT_CUDA<double>::~FFT_CUDA();
108114
}// namespace ModulePW

source/module_basis/module_pw/module_fft/fft_cuda.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#ifndef FFT_CUDA_H
22
#define FFT_CUDA_H
3-
43
#include "fft_base.h"
5-
#include "kernel/fft_cuda_func.h"
6-
4+
#include "cufft.h"
5+
#include "cuda_runtime.h"
76
namespace ModulePW
87
{
98
template <typename FPTYPE>
@@ -62,9 +61,6 @@ class FFT_CUDA : public FFT_BASE<FPTYPE>
6261
std::complex<double>* z_auxr_3d = nullptr; // fft space
6362

6463
};
65-
template FFT_CUDA<float>::FFT_CUDA();
66-
template FFT_CUDA<float>::~FFT_CUDA();
67-
template FFT_CUDA<double>::FFT_CUDA();
68-
template FFT_CUDA<double>::~FFT_CUDA();
64+
6965
} // namespace ModulePW
7066
#endif

source/module_basis/module_pw/module_fft/fft_rcom.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
12
#ifndef FFT_ROCM_H
23
#define FFT_ROCM_H
3-
44
#include "fft_base.h"
5-
#include "kernel/fft_rcom_func.h"
5+
#include <hipfft/hipfft.h>
6+
#include <hip/hip_runtime.h>
67
namespace ModulePW
78
{
89
template <typename FPTYPE>

source/module_basis/module_pw/module_fft/kernel/fft_cuda_func.h

Lines changed: 0 additions & 57 deletions
This file was deleted.

source/module_basis/module_pw/module_fft/kernel/fft_rcom_func.h

Lines changed: 0 additions & 53 deletions
This file was deleted.

source/module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define GEMM_SELECTOR_H
33

44
#include "module_cell/unitcell.h"
5-
#include "cuda_runtime.h"
5+
66
typedef std::function<void(int,
77
int,
88
int*,

source/module_hamilt_pw/hamilt_pwdft/global.h

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,47 @@ static const char* _cublasGetErrorString(cublasStatus_t error)
3939
return "<unknown>";
4040
}
4141

42+
static const char* _cufftGetErrorString(cufftResult_t error)
43+
{
44+
switch (error)
45+
{
46+
case CUFFT_SUCCESS:
47+
return "CUFFT_SUCCESS";
48+
case CUFFT_INVALID_PLAN:
49+
return "CUFFT_INVALID_PLAN";
50+
case CUFFT_ALLOC_FAILED:
51+
return "CUFFT_ALLOC_FAILED";
52+
case CUFFT_INVALID_TYPE:
53+
return "CUFFT_INVALID_TYPE";
54+
case CUFFT_INVALID_VALUE:
55+
return "CUFFT_INVALID_VALUE";
56+
case CUFFT_INTERNAL_ERROR:
57+
return "CUFFT_INTERNAL_ERROR";
58+
case CUFFT_EXEC_FAILED:
59+
return "CUFFT_EXEC_FAILED";
60+
case CUFFT_SETUP_FAILED:
61+
return "CUFFT_SETUP_FAILED";
62+
case CUFFT_INVALID_SIZE:
63+
return "CUFFT_INVALID_SIZE";
64+
case CUFFT_UNALIGNED_DATA:
65+
return "CUFFT_UNALIGNED_DATA";
66+
case CUFFT_INCOMPLETE_PARAMETER_LIST:
67+
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
68+
case CUFFT_INVALID_DEVICE:
69+
return "CUFFT_INVALID_DEVICE";
70+
case CUFFT_PARSE_ERROR:
71+
return "CUFFT_PARSE_ERROR";
72+
case CUFFT_NO_WORKSPACE:
73+
return "CUFFT_NO_WORKSPACE";
74+
case CUFFT_NOT_IMPLEMENTED:
75+
return "CUFFT_NOT_IMPLEMENTED";
76+
case CUFFT_LICENSE_ERROR:
77+
return "CUFFT_LICENSE_ERROR";
78+
case CUFFT_NOT_SUPPORTED:
79+
return "CUFFT_NOT_SUPPORTED";
80+
}
81+
return "<unknown>";
82+
}
4283

4384
#define CHECK_CUDA(func) \
4485
{ \
@@ -70,7 +111,15 @@ static const char* _cublasGetErrorString(cublasStatus_t error)
70111
} \
71112
}
72113

73-
114+
#define CHECK_CUFFT(func) \
115+
{ \
116+
cufftResult_t status = (func); \
117+
if (status != CUFFT_SUCCESS) \
118+
{ \
119+
printf("In File %s : CUFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
120+
_cufftGetErrorString(status), status); \
121+
} \
122+
}
74123
#endif // __CUDA
75124

76125
#ifdef __ROCM
@@ -118,6 +167,45 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error)
118167
// return "<unknown>";
119168
// }
120169

170+
static const char* _hipfftGetErrorString(hipfftResult_t error)
171+
{
172+
switch (error)
173+
{
174+
case HIPFFT_SUCCESS:
175+
return "HIPFFT_SUCCESS";
176+
case HIPFFT_INVALID_PLAN:
177+
return "HIPFFT_INVALID_PLAN";
178+
case HIPFFT_ALLOC_FAILED:
179+
return "HIPFFT_ALLOC_FAILED";
180+
case HIPFFT_INVALID_TYPE:
181+
return "HIPFFT_INVALID_TYPE";
182+
case HIPFFT_INVALID_VALUE:
183+
return "HIPFFT_INVALID_VALUE";
184+
case HIPFFT_INTERNAL_ERROR:
185+
return "HIPFFT_INTERNAL_ERROR";
186+
case HIPFFT_EXEC_FAILED:
187+
return "HIPFFT_EXEC_FAILED";
188+
case HIPFFT_SETUP_FAILED:
189+
return "HIPFFT_SETUP_FAILED";
190+
case HIPFFT_INVALID_SIZE:
191+
return "HIPFFT_INVALID_SIZE";
192+
case HIPFFT_UNALIGNED_DATA:
193+
return "HIPFFT_UNALIGNED_DATA";
194+
case HIPFFT_INCOMPLETE_PARAMETER_LIST:
195+
return "HIPFFT_INCOMPLETE_PARAMETER_LIST";
196+
case HIPFFT_INVALID_DEVICE:
197+
return "HIPFFT_INVALID_DEVICE";
198+
case HIPFFT_PARSE_ERROR:
199+
return "HIPFFT_PARSE_ERROR";
200+
case HIPFFT_NO_WORKSPACE:
201+
return "HIPFFT_NO_WORKSPACE";
202+
case HIPFFT_NOT_IMPLEMENTED:
203+
return "HIPFFT_NOT_IMPLEMENTED";
204+
case HIPFFT_NOT_SUPPORTED:
205+
return "HIPFFT_NOT_SUPPORTED";
206+
}
207+
return "<unknown>";
208+
}
121209

122210
#define CHECK_CUDA(func) \
123211
{ \
@@ -149,6 +237,15 @@ static const char* _hipblasGetErrorString(hipblasStatus_t error)
149237
// }\
150238
// }
151239

240+
#define CHECK_CUFFT(func) \
241+
{ \
242+
hipfftResult_t status = (func); \
243+
if (status != HIPFFT_SUCCESS) \
244+
{ \
245+
printf("In File %s : HIPFFT API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
246+
_hipfftGetErrorString(status), status); \
247+
} \
248+
}
152249
#endif // __ROCM
153250

154251
//==========================================================

0 commit comments

Comments
 (0)