diff --git a/source/module_hsolver/kernels/cuda/math_kernel_op.cu b/source/module_hsolver/kernels/cuda/math_kernel_op.cu
index 6185433895..149b9ce389 100644
--- a/source/module_hsolver/kernels/cuda/math_kernel_op.cu
+++ b/source/module_hsolver/kernels/cuda/math_kernel_op.cu
@@ -760,17 +760,22 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
     const char& trans,
     const int& m,
     const int& n,
-    const std::complex<float>* alpha,
+    const std::complex<float>* alpha_in,
     const std::complex<float>* A,
     const int& lda,
     const std::complex<float>* X,
     const int& incx,
-    const std::complex<float>* beta,
+    const std::complex<float>* beta_in,
     std::complex<float>* Y,
     const int& incy)
 {
     cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
-    cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
+    // alpha_in/beta_in are dereferenced here on the host and handed to cuBLAS as
+    // cuFloatComplex values (NOTE(review): &alpha/&beta presume host pointer mode -- confirm).
+    cuFloatComplex alpha = make_cuFloatComplex(alpha_in->real(), alpha_in->imag());
+    cuFloatComplex beta = make_cuFloatComplex(beta_in->real(), beta_in->imag());
+    // The stride of Y is incy (the previous call passed incx here).
+    cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incy));
 }
 
 template <>
@@ -778,17 +783,23 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
     const char& trans,
     const int& m,
     const int& n,
-    const std::complex<double>* alpha,
+    const std::complex<double>* alpha_in,
     const std::complex<double>* A,
     const int& lda,
     const std::complex<double>* X,
     const int& incx,
-    const std::complex<double>* beta,
+    const std::complex<double>* beta_in,
     std::complex<double>* Y,
     const int& incy)
 {
     cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
-    cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incx));
+    cuDoubleComplex alpha = make_cuDoubleComplex(alpha_in->real(), alpha_in->imag());
+    cuDoubleComplex beta = make_cuDoubleComplex(beta_in->real(), beta_in->imag());
+    // icpc and nvcc have compatibility problems, so alpha/beta are rebuilt as
+    // cuDoubleComplex values instead of casting std::complex<double>* to cuDoubleComplex*.
+    // (NOTE(review): &alpha/&beta presume host pointer mode -- confirm.)
+    // The stride of Y is incy (the previous call passed incx here).
+    cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incy));
 }
 
 template <>