11#include " module_base/kernels/math_kernel_op.h"
22
3+ #include < base/macros/macros.h>
34#include < thrust/complex.h>
45
/**
 * @brief Real-type trait specialization for thrust::complex<float>.
 *
 * Exposes `float` as the scalar type associated with the single-precision
 * thrust complex type, so templated kernels can take a real-valued argument
 * matching their element type.
 */
template <>
struct GetTypeReal<thrust::complex<float>> {
    using type = float; /**< The return type specialization for thrust::complex<float>. */
};
/**
 * @brief Real-type trait specialization for thrust::complex<double>.
 *
 * Exposes `double` as the scalar type associated with the double-precision
 * thrust complex type, so templated kernels can take a real-valued argument
 * matching their element type.
 */
template <>
struct GetTypeReal<thrust::complex<double>> {
    using type = double; /**< The return type specialization for thrust::complex<double>. */
};
514namespace ModuleBase
615{
// Number of threads per block used by the 1-D kernel launches in this file.
const int thread_per_block = 256;
// cuBLAS handle shared by the cuBLAS wrappers below. It starts out as nullptr.
// NOTE(review): it has to be created before any wrapper is called; the code
// that creates/destroys it is not visible in this part of the file.
static cublasHandle_t cublas_handle = nullptr;
18+
/**
 * @brief Single-precision dot product forwarded to cublasSdot.
 *
 * Uses the file-level `cublas_handle`; the returned cuBLAS status is passed
 * to `cublasErrcheck`.
 *
 * @param n      Number of elements in x and y.
 * @param x      First input vector.
 * @param incx   Stride between consecutive elements of x.
 * @param y      Second input vector.
 * @param incy   Stride between consecutive elements of y.
 * @param result Receives the dot product.
 *
 * NOTE(review): `result` is a host-side reference, so this presumably relies
 * on the handle being in host pointer mode (the cuBLAS default) - confirm.
 */
static inline
void xdot_wrapper(const int &n, const float *x, const int &incx, const float *y, const int &incy, float &result) {
    cublasErrcheck(cublasSdot(cublas_handle, n, x, incx, y, incy, &result));
}
23+
/**
 * @brief Double-precision dot product forwarded to cublasDdot.
 *
 * Uses the file-level `cublas_handle`; the returned cuBLAS status is passed
 * to `cublasErrcheck`.
 *
 * @param n      Number of elements in x and y.
 * @param x      First input vector.
 * @param incx   Stride between consecutive elements of x.
 * @param y      Second input vector.
 * @param incy   Stride between consecutive elements of y.
 * @param result Receives the dot product.
 *
 * NOTE(review): `result` is a host-side reference, so this presumably relies
 * on the handle being in host pointer mode (the cuBLAS default) - confirm.
 */
static inline
void xdot_wrapper(const int &n, const double *x, const int &incx, const double *y, const int &incy, double &result) {
    cublasErrcheck(cublasDdot(cublas_handle, n, x, incx, y, incy, &result));
}
728
829// Define the CUDA kernel:
9- template <typename FPTYPE >
30+ template <typename T >
1031__global__ void vector_mul_real_kernel (const int size,
11- thrust:: complex <FPTYPE> * result,
12- const thrust:: complex <FPTYPE> * vector,
13- const FPTYPE constant)
32+ T * result,
33+ const T * vector,
34+ const typename GetTypeReal<T>::type constant)
1435{
1536 int i = blockIdx .x * blockDim .x + threadIdx .x ;
1637 if (i < size)
@@ -87,6 +108,20 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
87108}
88109
89110// vector operator: result[i] = vector[i] * constant
/**
 * @brief GPU specialization of vector_mul_real_op for real double vectors.
 *
 * Launches vector_mul_real_kernel<double> over `dim` elements; per the
 * file's own description the operation is result[i] = vector[i] * constant.
 *
 * @param dim      Number of elements to process. Non-positive values are a no-op.
 * @param result   Output array handed to the kernel
 *                 (presumably device memory - confirm against callers).
 * @param vector   Input array handed to the kernel
 *                 (presumably device memory - confirm against callers).
 * @param constant Real factor applied to every element.
 */
template <>
void vector_mul_real_op<double, base_device::DEVICE_GPU>::operator()(const int dim,
                                                                     double* result,
                                                                     const double* vector,
                                                                     const double constant)
{
    // An empty (or negative) size would yield a grid of zero/negative blocks,
    // which is not a valid launch configuration; there is nothing to compute.
    if (dim <= 0)
    {
        return;
    }

    // In small cases, 1024 threads per block will only utilize 17 blocks, much less than 40,
    // so the smaller thread_per_block is used to spread the work over more blocks.
    int thread = thread_per_block;
    // Ceil-division so the tail elements are covered when dim is not a multiple of thread.
    int block = (dim + thread - 1) / thread;
    vector_mul_real_kernel<double><<<block, thread>>>(dim, result, vector, constant);

    cudaCheckOnDebug();
}
124+
90125template <typename FPTYPE>
91126inline void vector_mul_real_wrapper (const int dim,
92127 std::complex <FPTYPE>* result,
@@ -98,7 +133,7 @@ inline void vector_mul_real_wrapper(const int dim,
98133
99134 int thread = thread_per_block;
100135 int block = (dim + thread - 1 ) / thread;
101- vector_mul_real_kernel<FPTYPE><<<block, thread>>> (dim, result_tmp, vector_tmp, constant);
136+ vector_mul_real_kernel<thrust:: complex < FPTYPE> ><<<block, thread>>> (dim, result_tmp, vector_tmp, constant);
102137
103138 cudaCheckOnDebug ();
104139}
@@ -326,4 +361,25 @@ double dot_real_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(co
326361 return dot_complex_wrapper (dim, psi_L, psi_R, reduce);
327362}
328363
// Explicitly instantiate functors for the types of functor registered.
// NOTE(review): in this list `float` is instantiated only for
// vector_mul_vector_op and constantvector_addORsub_constantVector_op;
// confirm that the remaining functors do not need a GPU `float` variant.

// vector_mul_real_op
template struct vector_mul_real_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct vector_mul_real_op<double, base_device::DEVICE_GPU>;
template struct vector_mul_real_op<std::complex<double>, base_device::DEVICE_GPU>;

// vector_mul_vector_op / vector_div_vector_op
template struct vector_mul_vector_op<float, base_device::DEVICE_GPU>;
template struct vector_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct vector_mul_vector_op<double, base_device::DEVICE_GPU>;
template struct vector_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
template struct vector_div_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct vector_div_vector_op<double, base_device::DEVICE_GPU>;
template struct vector_div_vector_op<std::complex<double>, base_device::DEVICE_GPU>;

// constantvector_addORsub_constantVector_op
template struct constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>;
template struct constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>;
template struct constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>;

// dot_real_op
template struct dot_real_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct dot_real_op<double, base_device::DEVICE_GPU>;
template struct dot_real_op<std::complex<double>, base_device::DEVICE_GPU>;
329385} // namespace ModuleBase
0 commit comments