@@ -108,7 +108,7 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
108108 }
109109 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
110110#ifdef __CUDA
111- cublasErrcheck (cublasSscal (cublas_handle, n, (float2*) &alpha, (float2*) X, incX));
111+ cublasErrcheck (cublasSscal (cublas_handle, n, &alpha, X, incX));
112112#endif
113113 }
114114}
@@ -120,7 +120,7 @@ void BlasConnector::scal( const int n, const double alpha, double *X, const int
120120 }
121121 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
122122#ifdef __CUDA
123- cublasErrcheck (cublasDscal (cublas_handle, n, (double2*) &alpha, (double2*) X, incX));
123+ cublasErrcheck (cublasDscal (cublas_handle, n, &alpha, X, incX));
124124#endif
125125 }
126126}
@@ -202,7 +202,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
202202#ifdef __CUDA
203203 cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
204204 cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
205- cublasErrcheck (cublasSgemm (cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
205+ cublasErrcheck (cublasSgemm (cublas_handle, cutransA, cutransB, m, n, k, (float2*) &alpha, a, lda, b, ldb, (float2*) &beta, c, ldc));
206206#endif
207207 }
208208}
@@ -227,7 +227,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
227227#ifdef __CUDA
228228 cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
229229 cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
230- cublasErrcheck (cublasDgemm (cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
230+ cublasErrcheck (cublasDgemm (cublas_handle, cutransA, cutransB, m, n, k, (double2*) &alpha, a, lda, b, ldb, (double2*) &beta, c, ldc));
231231#endif
232232 }
233233}
@@ -252,7 +252,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
252252#ifdef __CUDA
253253 cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
254254 cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
255- cublasErrcheck (cublasCgemm (cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
255+ cublasErrcheck (cublasCgemm (cublas_handle, cutransA, cutransB, m, n, k, (float2*) &alpha, a, lda, b, ldb, (float2*) &beta, c, ldc));
256256#endif
257257 }
258258}
@@ -277,7 +277,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
277277#ifdef __CUDA
278278 cublasOperation_t cutransA = judge_trans_op (false , transa, " gemm_op" );
279279 cublasOperation_t cutransB = judge_trans_op (false , transb, " gemm_op" );
280- cublasErrcheck (cublasZgemm (cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
280+ cublasErrcheck (cublasZgemm (cublas_handle, cutransA, cutransB, m, n, k, (double2*) &alpha, a, lda, b, ldb, (double2*) &beta, c, ldc));
281281#endif
282282 }
283283}
@@ -322,7 +322,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
322322 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
323323#ifdef __CUDA
324324 cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
325- cublasErrcheck (cublasCgemv (cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
325+ cublasErrcheck (cublasCgemv (cublas_handle, cutrans, m, n, (float2*) &alpha, A, lda, X, incX, (float2*) &beta, Y, incY));
326326#endif
327327 }
328328}
@@ -337,7 +337,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
337337 else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
338338#ifdef __CUDA
339339 cublasOperation_t cutrans = judge_trans_op (false , trans, " gemv_op" );
340- cublasErrcheck (cublasZgemv (cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
340+ cublasErrcheck (cublasZgemv (cublas_handle, cutrans, m, n, (double2*) &alpha, A, lda, X, incX, (double2*) &beta, Y, incY));
341341#endif
342342 }
343343}
0 commit comments