Skip to content

Commit 7a497a5

Browse files
committed
Fix datatype
1 parent 71ef753 commit 7a497a5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

source/module_base/blas_connector.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
108108
}
109109
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
110110
#ifdef __CUDA
111-
cublasErrcheck(cublasSscal(cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
111+
cublasErrcheck(cublasSscal(cublas_handle, n, &alpha, X, incX));
112112
#endif
113113
}
114114
}
@@ -120,7 +120,7 @@ void BlasConnector::scal( const int n, const double alpha, double *X, const int
120120
}
121121
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
122122
#ifdef __CUDA
123-
cublasErrcheck(cublasDscal(cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
123+
cublasErrcheck(cublasDscal(cublas_handle, n, &alpha, X, incX));
124124
#endif
125125
}
126126
}
@@ -202,7 +202,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
202202
#ifdef __CUDA
203203
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
204204
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
205-
cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
205+
cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, a, lda, b, ldb, (float2*)&beta, c, ldc));
206206
#endif
207207
}
208208
}
@@ -227,7 +227,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
227227
#ifdef __CUDA
228228
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
229229
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
230-
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
230+
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, a, lda, b, ldb, (double2*)&beta, c, ldc));
231231
#endif
232232
}
233233
}
@@ -252,7 +252,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
252252
#ifdef __CUDA
253253
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
254254
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
255-
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
255+
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, a, lda, b, ldb, (float2*)&beta, c, ldc));
256256
#endif
257257
}
258258
}
@@ -277,7 +277,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
277277
#ifdef __CUDA
278278
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
279279
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
280-
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
280+
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, a, lda, b, ldb, (double2*)&beta, c, ldc));
281281
#endif
282282
}
283283
}
@@ -322,7 +322,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
322322
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
323323
#ifdef __CUDA
324324
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
325-
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
325+
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)&alpha, A, lda, X, incX, (float2*)&beta, Y, incY));
326326
#endif
327327
}
328328
}
@@ -337,7 +337,7 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
337337
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
338338
#ifdef __CUDA
339339
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
340-
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, &alpha, A, lda, X, incX, &beta, Y, incY));
340+
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)&alpha, A, lda, X, incX, (double2*)&beta, Y, incY));
341341
#endif
342342
}
343343
}

0 commit comments

Comments
 (0)