Skip to content

Commit 2770d93

Browse files
committed
Improve: Raise cuBLAS errors
1 parent 898f571 commit 2770d93

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

less_slow.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3051,6 +3051,12 @@ class unified_array {
30513051
// Always-false type trait: inherits `std::false_type` for every template
// argument, and the argument itself is never inspected.
// NOTE(review): presumably used to make a `static_assert` in a discarded
// `if constexpr` branch depend on a template parameter - confirm at call sites.
template <typename ignored_type_>
struct dependent_false : std::false_type {};
30533053

3054+
// Turns a failed cuBLAS status code into a C++ exception.
// Returns normally when `status` equals `CUBLAS_STATUS_SUCCESS`; for any other
// value throws `std::runtime_error` with a message of the form
// "cuBLAS error: <status name> - <status description>", where both parts are
// obtained from `cublasGetStatusName` and `cublasGetStatusString`.
void cublas_check(cublasStatus_t status) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        std::string message = "cuBLAS error: ";
        message += cublasGetStatusName(status);
        message += " - ";
        message += cublasGetStatusString(status);
        throw std::runtime_error(message);
    }
}
3059+
30543060
template <typename input_scalar_type_, typename output_scalar_type_ = input_scalar_type_>
30553061
static void cublas_tops(bm::State &state) {
30563062
// Matrix size and leading dimensions
@@ -3069,42 +3075,42 @@ static void cublas_tops(bm::State &state) {
30693075

30703076
// cuBLAS handle
30713077
cublasHandle_t handle;
3072-
cublasCreate(&handle);
3078+
cublas_check(cublasCreate(&handle));
30733079

30743080
// Perform the GEMM operation
30753081
// https://docs.nvidia.com/cuda/cublas/#cublas-t-gemm
30763082
for (auto _ : state) {
30773083
if constexpr (std::is_same_v<input_scalar_type_, float> && same_type) {
30783084
input_scalar_type_ alpha = 1, beta = 0;
3079-
cublasSgemm( //
3085+
cublas_check(cublasSgemm( //
30803086
handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, //
30813087
&alpha, a.begin(), lda, b.begin(), ldb, //
3082-
&beta, c.begin(), ldc);
3088+
&beta, c.begin(), ldc));
30833089
}
30843090
else if constexpr (std::is_same_v<input_scalar_type_, double> && same_type) {
30853091
input_scalar_type_ alpha = 1, beta = 0;
3086-
cublasDgemm( //
3092+
cublas_check(cublasDgemm( //
30873093
handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, //
30883094
&alpha, a.begin(), lda, b.begin(), ldb, //
3089-
&beta, c.begin(), ldc);
3095+
&beta, c.begin(), ldc));
30903096
}
30913097
else if constexpr (std::is_same_v<input_scalar_type_, __half> && same_type) {
30923098
input_scalar_type_ alpha = 1, beta = 0;
3093-
cublasHgemm( //
3099+
cublas_check(cublasHgemm( //
30943100
handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, //
30953101
&alpha, a.begin(), lda, b.begin(), ldb, //
3096-
&beta, c.begin(), ldc);
3102+
&beta, c.begin(), ldc));
30973103
}
30983104
else if constexpr (std::is_same_v<input_scalar_type_, int8_t> && std::is_same_v<output_scalar_type_, int32_t>) {
30993105
// Scaling factors must correspond to the accumulator type
31003106
// https://docs.nvidia.com/cuda/cublas/#cublasgemmex
31013107
int32_t alpha_int = 1, beta_int = 0;
3102-
cublasGemmEx( //
3108+
cublas_check(cublasGemmEx( //
31033109
handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, //
31043110
&alpha_int, a.begin(), CUDA_R_8I, lda, //
31053111
b.begin(), CUDA_R_8I, ldb, //
31063112
&beta_int, c.begin(), CUDA_R_32I, ldc, //
3107-
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
3113+
CUDA_R_32I, CUBLAS_GEMM_DEFAULT));
31083114
}
31093115
// Trigger a compile-time error for unsupported type combinations
31103116
else {
@@ -3122,7 +3128,7 @@ static void cublas_tops(bm::State &state) {
31223128
state.SetComplexityN(n);
31233129

31243130
// Cleanup
3125-
cublasDestroy(handle);
3131+
cublas_check(cublasDestroy(handle));
31263132
}
31273133

31283134
// Register benchmarks

0 commit comments

Comments
 (0)