Skip to content

Commit aefe880

Browse files
committed
Reapply "CUDA: fix FP16 cuBLAS GEMM (ggml-org#11396)"
This reverts commit f0e6b2a.
1 parent d1d43c5 commit aefe880

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1263,8 +1263,8 @@ static void ggml_cuda_op_mul_mat_cublas(
1263 1263   CUBLAS_CHECK(
1264 1264   cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1265 1265   row_diff, src1_ncols, ne10,
1266      - &alpha, src0_ptr, CUDA_R_16F, ne00,
1267      - src1_ptr, CUDA_R_16F, ne10,
     1266 + &alpha, src0_ptr, CUDA_R_16F, ne00,
     1267 + src1_ptr, CUDA_R_16F, ne10,
1268 1268   &beta, dst_dd_i, CUDA_R_32F, ldc,
1269 1269   CUBLAS_COMPUTE_32F,
1270 1270   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1277,9 +1277,9 @@ static void ggml_cuda_op_mul_mat_cublas(
1277 1277   CUBLAS_CHECK(
1278 1278   cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1279 1279   row_diff, src1_ncols, ne10,
1280      - &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1281      - src1_ptr, CUDA_R_16F, ne10,
1282      - &beta_f16, dst_dd_i, CUDA_R_16F, ldc,
     1280 + &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
     1281 + src1_ptr, CUDA_R_16F, ne10,
     1282 + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
1283 1283   CUBLAS_COMPUTE_16F,
1284 1284   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1285 1285

0 commit comments

Comments
 (0)