@@ -760,35 +760,41 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
760760 const char & trans,
761761 const int & m,
762762 const int & n,
763- const std::complex <float >* alpha ,
763+ const std::complex <float >* alpha_in ,
764764 const std::complex <float >* A,
765765 const int & lda,
766766 const std::complex <float >* X,
767767 const int & incx,
768- const std::complex <float >* beta ,
768+ const std::complex <float >* beta_in ,
769769 std::complex <float >* Y,
770770 const int & incy)
771771{
772772 cublasOperation_t cutrans = judge_trans_op (true , trans, " gemv_op" );
773- cublasErrcheck (cublasCgemv (cublas_handle, cutrans, m, n, (float2 *)alpha, (float2 *)A, lda, (float2 *)X, incx, (float2 *)beta, (float2 *)Y, incx));
773+ cuFloatComplex alpha = make_cuFloatComplex (alpha_in->real (), alpha_in->imag ());
774+ cuFloatComplex beta = make_cuFloatComplex (beta_in->real (), beta_in->imag ());
775+ cublasErrcheck (cublasCgemv (cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incx));
774776}
775777
776778template <>
777779void gemv_op<std::complex <double >, base_device::DEVICE_GPU>::operator ()(const base_device::DEVICE_GPU* d,
778780 const char & trans,
779781 const int & m,
780782 const int & n,
781- const std::complex <double >* alpha ,
783+ const std::complex <double >* alpha_in ,
782784 const std::complex <double >* A,
783785 const int & lda,
784786 const std::complex <double >* X,
785787 const int & incx,
786- const std::complex <double >* beta ,
788+ const std::complex <double >* beta_in ,
787789 std::complex <double >* Y,
788790 const int & incy)
789791{
790792 cublasOperation_t cutrans = judge_trans_op (true , trans, " gemv_op" );
791- cublasErrcheck (cublasZgemv (cublas_handle, cutrans, m, n, (double2 *)alpha, (double2 *)A, lda, (double2 *)X, incx, (double2 *)beta, (double2 *)Y, incx));
793+ cuDoubleComplex alpha = make_cuDoubleComplex (alpha_in->real (), alpha_in->imag ());
794+ cuDoubleComplex beta = make_cuDoubleComplex (beta_in->real (), beta_in->imag ());
795+ // icpc and nvcc have some compatible problems
796+ // We must use cuDoubleComplex instead of converting std::complex<double>* to cuDoubleComplex*
797+ cublasErrcheck (cublasZgemv (cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incx));
792798}
793799
794800template <>
0 commit comments