Skip to content

Commit 767e4d4

Browse files
committed
fix compile
1 parent d87aba3 commit 767e4d4

19 files changed

+170
-67
lines changed

python/pyabacus/CONTRIBUTING.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ list(APPEND _diago
190190
${HSOLVER_PATH}/diag_const_nums.cpp
191191
${HSOLVER_PATH}/diago_iter_assist.cpp
192192
${HSOLVER_PATH}/kernels/dngvd_op.cpp
193+
${HSOLVER_PATH}/kernels/bpcg_kernel_op.cpp
193194
${BASE_PATH}/kernels/math_kernel_op.cpp
194195
${BASE_PATH}/kernels/math_kernel_op_vec.cpp
195196
${BASE_PATH}/kernels/math_ylm_op.cpp

python/pyabacus/src/hsolver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ list(APPEND _diago
1010

1111

1212
${HSOLVER_PATH}/kernels/dngvd_op.cpp
13+
${HSOLVER_PATH}/kernels/bpcg_kernel_op.cpp
1314
# dependency
1415
${BASE_PATH}/kernels/math_kernel_op.cpp
1516
${BASE_PATH}/kernels/math_kernel_op_vec.cpp

source/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ list(APPEND device_srcs
3636
module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp
3737
module_basis/module_pw/kernels/pw_op.cpp
3838
module_hsolver/kernels/dngvd_op.cpp
39+
module_hsolver/kernels/bpcg_kernel_op.cpp
3940
module_elecstate/kernels/elecstate_op.cpp
4041

4142
# module_psi/kernels/psi_memory_op.cpp
@@ -65,6 +66,7 @@ if(USE_CUDA)
6566
module_hamilt_pw/hamilt_pwdft/kernels/cuda/onsite_op.cu
6667
module_basis/module_pw/kernels/cuda/pw_op.cu
6768
module_hsolver/kernels/cuda/dngvd_op.cu
69+
module_hsolver/kernels/cuda/bpcg_kernel_op.cu
6870
module_elecstate/kernels/cuda/elecstate_op.cu
6971

7072
# module_psi/kernels/cuda/memory_op.cu
@@ -91,6 +93,7 @@ if(USE_ROCM)
9193
module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu
9294
module_basis/module_pw/kernels/rocm/pw_op.hip.cu
9395
module_hsolver/kernels/rocm/dngvd_op.hip.cu
96+
module_hsolver/kernels/rocm/bpcg_kernel_op.hip.cu
9497
module_elecstate/kernels/rocm/elecstate_op.hip.cu
9598

9699
# module_psi/kernels/rocm/memory_op.hip.cu

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ OBJS_HSOLVER=diago_cg.o\
350350
hsolver_pw_sdft.o\
351351
diago_iter_assist.o\
352352
dngvd_op.o\
353+
bpcg_kernel_op.o\
353354
diag_const_nums.o\
354355
diag_hs_para.o\
355356
diago_pxxxgvx.o\

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,6 @@ const int warp_size = 32;
1616
const int thread_per_block = 256;
1717
}
1818

19-
template <>
20-
struct GetTypeReal<thrust::complex<float>> {
21-
using type = float; /**< The return type specialization for thrust::complex<float>. */
22-
};
23-
template <>
24-
struct GetTypeReal<thrust::complex<double>> {
25-
using type = double; /**< The return type specialization for thrust::complex<double>. */
26-
};
2719
namespace ModuleBase {
2820
template <typename T>
2921
struct GetTypeThrust {
@@ -42,16 +34,6 @@ struct GetTypeThrust<std::complex<double>> {
4234

4335
static cublasHandle_t cublas_handle = nullptr;
4436

45-
static inline
46-
void xdot_wrapper(const int &n, const float * x, const int &incx, const float * y, const int &incy, float &result) {
47-
cublasErrcheck(cublasSdot(cublas_handle, n, x, incx, y, incy, &result));
48-
}
49-
50-
static inline
51-
void xdot_wrapper(const int &n, const double * x, const int &incx, const double * y, const int &incy, double &result) {
52-
cublasErrcheck(cublasDdot(cublas_handle, n, x, incx, y, incy, &result));
53-
}
54-
5537
void createGpuBlasHandle(){
5638
if (cublas_handle == nullptr) {
5739
cublasErrcheck(cublasCreate(&cublas_handle));

source/module_base/kernels/cuda/math_kernel_op_vec.cu

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,37 @@
11
#include "module_base/kernels/math_kernel_op.h"
22

3+
#include <base/macros/macros.h>
34
#include <thrust/complex.h>
45

6+
template <>
7+
struct GetTypeReal<thrust::complex<float>> {
8+
using type = float; /**< The return type specialization for thrust::complex<float>. */
9+
};
10+
template <>
11+
struct GetTypeReal<thrust::complex<double>> {
12+
using type = double; /**< The return type specialization for thrust::complex<double>. */
13+
};
514
namespace ModuleBase
615
{
16+
const int thread_per_block = 256;
17+
static cublasHandle_t cublas_handle = nullptr;
18+
19+
static inline
20+
void xdot_wrapper(const int &n, const float * x, const int &incx, const float * y, const int &incy, float &result) {
21+
cublasErrcheck(cublasSdot(cublas_handle, n, x, incx, y, incy, &result));
22+
}
23+
24+
static inline
25+
void xdot_wrapper(const int &n, const double * x, const int &incx, const double * y, const int &incy, double &result) {
26+
cublasErrcheck(cublasDdot(cublas_handle, n, x, incx, y, incy, &result));
27+
}
728

829
// Define the CUDA kernel:
9-
template <typename FPTYPE>
30+
template <typename T>
1031
__global__ void vector_mul_real_kernel(const int size,
11-
thrust::complex<FPTYPE>* result,
12-
const thrust::complex<FPTYPE>* vector,
13-
const FPTYPE constant)
32+
T* result,
33+
const T* vector,
34+
const typename GetTypeReal<T>::type constant)
1435
{
1536
int i = blockIdx.x * blockDim.x + threadIdx.x;
1637
if (i < size)
@@ -87,6 +108,20 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
87108
}
88109

89110
// vector operator: result[i] = vector[i] * constant
111+
template <>
112+
void vector_mul_real_op<double, base_device::DEVICE_GPU>::operator()(const int dim,
113+
double* result,
114+
const double* vector,
115+
const double constant)
116+
{
117+
// In small cases, 1024 threads per block will only utilize 17 blocks, much less than 40
118+
int thread = thread_per_block;
119+
int block = (dim + thread - 1) / thread;
120+
vector_mul_real_kernel<double><<<block, thread>>>(dim, result, vector, constant);
121+
122+
cudaCheckOnDebug();
123+
}
124+
90125
template <typename FPTYPE>
91126
inline void vector_mul_real_wrapper(const int dim,
92127
std::complex<FPTYPE>* result,
@@ -98,7 +133,7 @@ inline void vector_mul_real_wrapper(const int dim,
98133

99134
int thread = thread_per_block;
100135
int block = (dim + thread - 1) / thread;
101-
vector_mul_real_kernel<FPTYPE><<<block, thread>>>(dim, result_tmp, vector_tmp, constant);
136+
vector_mul_real_kernel<thrust::complex<FPTYPE>><<<block, thread>>>(dim, result_tmp, vector_tmp, constant);
102137

103138
cudaCheckOnDebug();
104139
}
@@ -326,4 +361,25 @@ double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(co
326361
return dot_complex_wrapper(dim, psi_L, psi_R, reduce);
327362
}
328363

364+
// Explicitly instantiate functors for the types of functor registered.
365+
template struct vector_mul_real_op<std::complex<float>, base_device::DEVICE_GPU>;
366+
template struct vector_mul_real_op<double, base_device::DEVICE_GPU>;
367+
template struct vector_mul_real_op<std::complex<double>, base_device::DEVICE_GPU>;
368+
369+
template struct vector_mul_vector_op<float, base_device::DEVICE_GPU>;
370+
template struct vector_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
371+
template struct vector_mul_vector_op<double, base_device::DEVICE_GPU>;
372+
template struct vector_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
373+
template struct vector_div_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
374+
template struct vector_div_vector_op<double, base_device::DEVICE_GPU>;
375+
template struct vector_div_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
376+
377+
template struct constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>;
378+
template struct constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>;
379+
template struct constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>;
380+
template struct constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>;
381+
382+
template struct dot_real_op<std::complex<float>, base_device::DEVICE_GPU>;
383+
template struct dot_real_op<double, base_device::DEVICE_GPU>;
384+
template struct dot_real_op<std::complex<double>, base_device::DEVICE_GPU>;
329385
} // namespace ModuleBase

source/module_base/kernels/math_kernel_op.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ template <typename FPTYPE, typename Device> struct scal_op {
6666
const int &incx);
6767
};
6868

69-
// vector operator: result[i] = vector[i] * constant
70-
template <typename FPTYPE, typename Device> struct vector_mul_real_op {
71-
/// @brief result[i] = vector[i] * constant, where vector is complex number and constant is real number
69+
template <typename T, typename Device> struct vector_mul_real_op {
70+
using Real = typename GetTypeReal<T>::type;
71+
/// @brief result[i] = vector[i] * constant, where vector is complex number and constant is real number.
72+
/// It is different from the scal_op, which is used to multiply a complex number by a complex number.
7273
///
7374
/// Input Parameters
7475
/// \param dim : array size
@@ -78,8 +79,7 @@ template <typename FPTYPE, typename Device> struct vector_mul_real_op {
7879
/// Output Parameters
7980
/// \param result : output array
8081
/// \note Use multiply instead of divide. It is faster.
81-
void operator()(const int dim, std::complex<FPTYPE> *result, const std::complex<FPTYPE> *vector,
82-
const FPTYPE constant);
82+
void operator()(const int dim, T* result, const T* vector, const Real constant);
8383
};
8484

8585
// vector operator: result[i] = vector1[i](complex) * vector2[i](not complex)
@@ -293,13 +293,11 @@ template <typename T> struct dot_real_op<T, base_device::DEVICE_GPU> {
293293
};
294294

295295
// vector operator: result[i] = vector[i] * constant
296-
template <typename FPTYPE>
297-
struct vector_mul_real_op<FPTYPE, base_device::DEVICE_GPU>
296+
template <typename T>
297+
struct vector_mul_real_op<T, base_device::DEVICE_GPU>
298298
{
299-
void operator()(const int dim,
300-
std::complex<FPTYPE>* result,
301-
const std::complex<FPTYPE>* vector,
302-
const FPTYPE constant);
299+
using Real = typename GetTypeReal<T>::type;
300+
void operator()(const int dim, T* result, const T* vector, const Real constant);
303301
};
304302

305303
// vector operator: result[i] = vector1[i](complex) * vector2[i](not complex)

source/module_base/kernels/math_kernel_op_vec.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@ struct scal_op<FPTYPE, base_device::DEVICE_CPU>
1515
}
1616
};
1717

18-
template <typename FPTYPE>
19-
struct vector_mul_real_op<FPTYPE, base_device::DEVICE_CPU>
18+
template <typename T>
19+
struct vector_mul_real_op<T, base_device::DEVICE_CPU>
2020
{
21-
void operator()(const int dim,
22-
std::complex<FPTYPE>* result,
23-
const std::complex<FPTYPE>* vector,
24-
const FPTYPE constant)
21+
using Real = typename GetTypeReal<T>::type;
22+
void operator()(const int dim, T* result, const T* vector, const Real constant)
2523
{
2624
#ifdef _OPENMP
2725
#pragma omp parallel for schedule(static, 4096 / sizeof(Real))
@@ -153,8 +151,9 @@ struct dot_real_op<std::complex<FPTYPE>, base_device::DEVICE_CPU>
153151
template struct scal_op<float, base_device::DEVICE_CPU>;
154152
template struct scal_op<double, base_device::DEVICE_CPU>;
155153

156-
template struct vector_mul_real_op<float, base_device::DEVICE_CPU>;
154+
template struct vector_mul_real_op<std::complex<float>, base_device::DEVICE_CPU>;
157155
template struct vector_mul_real_op<double, base_device::DEVICE_CPU>;
156+
template struct vector_mul_real_op<std::complex<double>, base_device::DEVICE_CPU>;
158157

159158
template struct vector_mul_vector_op<std::complex<float>, base_device::DEVICE_CPU>;
160159
template struct vector_mul_vector_op<double, base_device::DEVICE_CPU>;

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,6 @@ struct GetTypeThrust<std::complex<double>> {
3939

4040
static hipblasHandle_t cublas_handle = nullptr;
4141

42-
static inline
43-
void xdot_wrapper(const int &n, const float * x, const int &incx, const float * y, const int &incy, float &result) {
44-
hipblasErrcheck(hipblasSdot(cublas_handle, n, x, incx, y, incy, &result));
45-
}
46-
47-
static inline
48-
void xdot_wrapper(const int &n, const double * x, const int &incx, const double * y, const int &incy, double &result) {
49-
hipblasErrcheck(hipblasDdot(cublas_handle, n, x, incx, y, incy, &result));
50-
}
51-
5242
void createGpuBlasHandle(){
5343
if (cublas_handle == nullptr) {
5444
hipblasErrcheck(hipblasCreate(&cublas_handle));

source/module_base/kernels/rocm/math_kernel_op_vec.hip.cu

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,35 @@
11
#include "module_base/kernels/math_kernel_op.h"
22

3+
#include <base/macros/macros.h>
34
#include <thrust/complex.h>
4-
5+
template <>
6+
struct GetTypeReal<thrust::complex<float>> {
7+
using type = float; /**< The return type specialization for thrust::complex<float>. */
8+
};
9+
template <>
10+
struct GetTypeReal<thrust::complex<double>> {
11+
using type = double; /**< The return type specialization for thrust::complex<double>. */
12+
};
513
namespace ModuleBase
614
{
15+
16+
static hipblasHandle_t cublas_handle = nullptr;
17+
static inline
18+
void xdot_wrapper(const int &n, const float * x, const int &incx, const float * y, const int &incy, float &result) {
19+
hipblasErrcheck(hipblasSdot(cublas_handle, n, x, incx, y, incy, &result));
20+
}
21+
22+
static inline
23+
void xdot_wrapper(const int &n, const double * x, const int &incx, const double * y, const int &incy, double &result) {
24+
hipblasErrcheck(hipblasDdot(cublas_handle, n, x, incx, y, incy, &result));
25+
}
26+
727
// Define the CUDA kernel:
8-
template <typename FPTYPE>
28+
template <typename T>
929
__launch_bounds__(1024) __global__ void vector_mul_real_kernel(const int size,
10-
thrust::complex<FPTYPE>* result,
11-
const thrust::complex<FPTYPE>* vector,
12-
FPTYPE constant)
30+
T* result,
31+
const T* vector,
32+
const typename GetTypeReal<T>::type constant)
1333
{
1434
int i = blockIdx.x * blockDim.x + threadIdx.x;
1535
if (i < size)
@@ -86,6 +106,26 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
86106
}
87107

88108
// vector operator: result[i] = vector[i] * constant
109+
template <>
110+
void vector_mul_real_op<double, base_device::DEVICE_GPU>::operator()(const int dim,
111+
double* result,
112+
const double* vector,
113+
const double constant)
114+
{
115+
int thread = 1024;
116+
int block = (dim + thread - 1) / thread;
117+
hipLaunchKernelGGL(HIP_KERNEL_NAME(vector_div_constant_kernel<double>),
118+
dim3(block),
119+
dim3(thread),
120+
0,
121+
0,
122+
dim,
123+
result,
124+
vector,
125+
constant);
126+
127+
hipCheckOnDebug();
128+
}
89129
template <typename FPTYPE>
90130
inline void vector_mul_real_wrapper(const int dim,
91131
std::complex<FPTYPE>* result,
@@ -96,7 +136,7 @@ inline void vector_mul_real_wrapper(const int dim,
96136
const thrust::complex<FPTYPE>* vector_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector);
97137
int thread = 1024;
98138
int block = (dim + thread - 1) / thread;
99-
hipLaunchKernelGGL(HIP_KERNEL_NAME(vector_mul_real_kernel<FPTYPE>),
139+
hipLaunchKernelGGL(HIP_KERNEL_NAME(vector_mul_real_kernel<thrust::complex<FPTYPE>>),
100140
dim3(block),
101141
dim3(thread),
102142
0,
@@ -378,4 +418,26 @@ double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(co
378418
{
379419
return dot_complex_wrapper(dim, psi_L, psi_R, reduce);
380420
}
421+
422+
// Explicitly instantiate functors for the types of functor registered.
423+
template struct vector_mul_real_op<std::complex<float>, base_device::DEVICE_GPU>;
424+
template struct vector_mul_real_op<double, base_device::DEVICE_GPU>;
425+
template struct vector_mul_real_op<std::complex<double>, base_device::DEVICE_GPU>;
426+
427+
template struct vector_mul_vector_op<float, base_device::DEVICE_GPU>;
428+
template struct vector_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
429+
template struct vector_mul_vector_op<double, base_device::DEVICE_GPU>;
430+
template struct vector_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
431+
template struct vector_div_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
432+
template struct vector_div_vector_op<double, base_device::DEVICE_GPU>;
433+
template struct vector_div_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
434+
435+
template struct constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>;
436+
template struct constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>;
437+
template struct constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>;
438+
template struct constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>;
439+
440+
template struct dot_real_op<std::complex<float>, base_device::DEVICE_GPU>;
441+
template struct dot_real_op<double, base_device::DEVICE_GPU>;
442+
template struct dot_real_op<std::complex<double>, base_device::DEVICE_GPU>;
381443
} // namespace ModuleBase

0 commit comments

Comments
 (0)