Skip to content

Commit a043ff1

Browse files
CUDA: fix strided GEMM for [0,2,1,3] per && ne2==1
1 parent 711d5e6 commit a043ff1

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,12 +1921,17 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19211921
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
19221922
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
19231923
// use cublasGemmStridedBatchedEx
1924+
1925+
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1926+
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1927+
const int64_t smb = ne02 == 1 ? s13 : s12;
1928+
19241929
CUBLAS_CHECK(
19251930
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
19261931
ne01, ne11, ne10,
1927-
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1928-
src1_ptr, cu_data_type_b, s11, s12, // strideB
1929-
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1932+
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1933+
src1_ptr, cu_data_type_b, s11, smb, // strideB
1934+
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
19301935
ne12*ne13,
19311936
cu_compute_type,
19321937
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

0 commit comments

Comments
 (0)