Skip to content

Commit 28cb769

Browse files
authored
fix: segmentation fault in GPU-Davidson (#5763)
It is caused by a compatibility problem between the C++ compiler and nvcc.
1 parent 91b0281 commit 28cb769

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

source/module_hsolver/kernels/cuda/math_kernel_op.cu

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -760,35 +760,41 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
760760
const char& trans,
761761
const int& m,
762762
const int& n,
763-
const std::complex<float>* alpha,
763+
const std::complex<float>* alpha_in,
764764
const std::complex<float>* A,
765765
const int& lda,
766766
const std::complex<float>* X,
767767
const int& incx,
768-
const std::complex<float>* beta,
768+
const std::complex<float>* beta_in,
769769
std::complex<float>* Y,
770770
const int& incy)
771771
{
772772
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
773-
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
773+
cuFloatComplex alpha = make_cuFloatComplex(alpha_in->real(), alpha_in->imag());
774+
cuFloatComplex beta = make_cuFloatComplex(beta_in->real(), beta_in->imag());
775+
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta, (cuFloatComplex*)Y, incx));
774776
}
775777

776778
template <>
777779
void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
778780
const char& trans,
779781
const int& m,
780782
const int& n,
781-
const std::complex<double>* alpha,
783+
const std::complex<double>* alpha_in,
782784
const std::complex<double>* A,
783785
const int& lda,
784786
const std::complex<double>* X,
785787
const int& incx,
786-
const std::complex<double>* beta,
788+
const std::complex<double>* beta_in,
787789
std::complex<double>* Y,
788790
const int& incy)
789791
{
790792
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
791-
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incx));
793+
cuDoubleComplex alpha = make_cuDoubleComplex(alpha_in->real(), alpha_in->imag());
794+
cuDoubleComplex beta = make_cuDoubleComplex(beta_in->real(), beta_in->imag());
795+
// icpc and nvcc have some compatibility problems.
796+
// The alpha/beta scalars must be rebuilt as cuDoubleComplex values (make_cuDoubleComplex) rather than passed by casting std::complex<double>* to cuDoubleComplex*.
797+
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta, (cuDoubleComplex*)Y, incx));
792798
}
793799

794800
template <>

0 commit comments

Comments
 (0)