@@ -133,6 +133,14 @@ __global__ void matrix_copy_kernel(const int n1, const int n2, const T* A, const
133133 }
134134}
135135
// Scale each column of an m x n matrix by a per-column factor:
//   c(row, col) = alpha * b[col] * a(row, col)
// Matrices use column-major indexing (element (row, col) at col * ld + row),
// with leading dimensions lda/ldc (expected >= m). b has length n.
// Expects a 2D launch where x covers rows (m) and y covers columns (n);
// the guard below handles grids that overshoot either dimension.
template <typename T, typename Real>
__global__ void matrix_multiply_vector_kernel(const int m,
                                              const int n,
                                              const T* a,
                                              const int lda,
                                              const Real* b,
                                              const Real alpha,
                                              T* c,
                                              const int ldc)
{
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int col = blockIdx.y * blockDim.y + threadIdx.y;
    if (col >= n || row >= m)
    {
        return;
    }
    // Adjacent threads in x touch adjacent rows of a column, so global
    // loads/stores are coalesced.
    c[col * ldc + row] = a[col * lda + row] * b[col] * alpha;
}
136144cublasOperation_t judge_trans_op (bool is_complex, const char & trans, const char * name)
137145{
138146 if (trans == ' N' )
@@ -147,7 +155,7 @@ cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char*
147155 {
148156 return CUBLAS_OP_C;
149157 }
150- else
158+ else
151159 {
152160 ModuleBase::WARNING_QUIT (name, std::string (" Unknown trans type " ) + trans + std::string (" !" ));
153161 }
@@ -438,10 +446,44 @@ void matrixCopy<std::complex<double>, base_device::DEVICE_GPU>::operator()(const
438446 cudaCheckOnDebug ();
439447}
440448
// GPU specialization: c = alpha * a * diag(b) for real double matrices,
// i.e. every element of column j is scaled by alpha * b[j].
// a, b, c are device pointers; a is m x n with leading dimension lda,
// c is m x n with leading dimension ldc, b has length n.
template <>
void matrix_mul_vector_op<double, base_device::DEVICE_GPU>::operator()(const int& m, const int& n,
    double* a, const int& lda, const double* b, const double alpha, double* c, const int& ldc)
{
    // Empty matrix: nothing to do. Also avoids launching with a zero-sized
    // grid, which is an invalid CUDA launch configuration.
    if (m <= 0 || n <= 0)
    {
        return;
    }
    const dim3 thread(16, 16, 1);
    // Ceil-division so the grid covers m x n even when not a multiple of 16.
    const dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
    matrix_multiply_vector_kernel<double, double><<<block, thread>>>(m, n, a, lda, b, alpha, c, ldc);
    cudaCheckOnDebug();
}
458+
// GPU specialization: c = alpha * a * diag(b) for single-precision complex
// matrices with a real scaling vector b and real alpha.
// std::complex<float> device pointers are reinterpreted as thrust::complex<float>,
// which is layout-compatible, so the kernel can do complex arithmetic on device.
template <>
void matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int& m, const int& n,
    std::complex<float>* a, const int& lda, const float* b, const float alpha, std::complex<float>* c, const int& ldc)
{
    // Empty matrix: nothing to do. Also avoids launching with a zero-sized
    // grid, which is an invalid CUDA launch configuration.
    if (m <= 0 || n <= 0)
    {
        return;
    }
    const dim3 thread(16, 16, 1);
    // Ceil-division so the grid covers m x n even when not a multiple of 16.
    const dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
    matrix_multiply_vector_kernel<thrust::complex<float>, float><<<block, thread>>>(
        m, n, reinterpret_cast<thrust::complex<float>*>(a), lda,
        b, alpha, reinterpret_cast<thrust::complex<float>*>(c), ldc);
    cudaCheckOnDebug();
}
468+
// GPU specialization: c = alpha * a * diag(b) for double-precision complex
// matrices with a real scaling vector b and real alpha.
// std::complex<double> device pointers are reinterpreted as thrust::complex<double>,
// which is layout-compatible, so the kernel can do complex arithmetic on device.
template <>
void matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int& m, const int& n,
    std::complex<double>* a, const int& lda, const double* b, const double alpha, std::complex<double>* c, const int& ldc)
{
    // Empty matrix: nothing to do. Also avoids launching with a zero-sized
    // grid, which is an invalid CUDA launch configuration.
    if (m <= 0 || n <= 0)
    {
        return;
    }
    const dim3 thread(16, 16, 1);
    // Ceil-division so the grid covers m x n even when not a multiple of 16.
    const dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
    matrix_multiply_vector_kernel<thrust::complex<double>, double><<<block, thread>>>(
        m, n, reinterpret_cast<thrust::complex<double>*>(a), lda,
        b, alpha, reinterpret_cast<thrust::complex<double>*>(c), ldc);
    cudaCheckOnDebug();
}
441479
// Explicit instantiations of the functors registered for the GPU device,
// so their definitions in this translation unit are emitted for linkers.
template struct matrixCopy<std::complex<float>, base_device::DEVICE_GPU>;
template struct matrixCopy<double, base_device::DEVICE_GPU>;
template struct matrixCopy<std::complex<double>, base_device::DEVICE_GPU>;

template struct matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct matrix_mul_vector_op<double, base_device::DEVICE_GPU>;
template struct matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
447489} // namespace ModuleBase