@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
+    bool is_src0_cont_2 = ggml_is_contiguous(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;
 
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }
 
     // Setup destination buffer
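
Note on the added is_src1_cont_2 = true; above: when src1 has to be converted, the copy is written densely into the freshly allocated pool buffer, so the converted data is contiguous across dims 2 and 3 by construction, whatever the layout of the original src1 was; the element strides s11, s12, s13 set just before describe that packed layout. Below is a minimal sketch of those packed strides, assuming a dense copy with extents ne10..ne13; the struct and function names are illustrative only, not part of ggml or of this patch.

#include <cstdint>

struct packed_strides { int64_t s11, s12, s13; };

// Element strides of a densely packed copy with extents (ne10, ne11, ne12, ne13):
// rows, matrices and dim-2 blocks follow each other with no padding, which is
// exactly the precondition the strided-batched GEMM path below relies on.
static packed_strides dense_strides(int64_t ne10, int64_t ne11, int64_t ne12) {
    packed_strides s;
    s.s11 = ne10;        // elements per row
    s.s12 = ne11*s.s11;  // elements per ne10 x ne11 matrix
    s.s13 = ne12*s.s12;  // elements per dim-3 slice (ne12 matrices)
    return s;
}
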
@@ -1918,15 +1923,18 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        const size_t sa = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const size_t sb = ne12 == 1 ? s13 : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sa,      // strideA
+                       src1_ptr, cu_data_type_b, s11,       sb,      // strideB
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
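
Note on the new sa/sb strides: cublasGemmStridedBatchedEx batches the ne12*ne13 GEMMs with a single constant stride between consecutive A matrices (and likewise for B). When ne02 == 1 the batch effectively advances along dim 3 of src0, so the per-matrix stride has to come from nb03 instead of nb02; symmetrically, when ne12 == 1 the src1 stride is s13 instead of s12. A minimal sketch of that selection follows; pick_stride_a and pick_stride_b are hypothetical helper names used only for illustration, not functions from this patch or from ggml.

#include <cstddef>
#include <cstdint>

// Illustrative helpers mirroring the stride selection above. All strides are
// in elements (byte strides divided by nb00), since cuBLAS takes element
// strides for its strided-batched GEMM.
static size_t pick_stride_a(int64_t ne02, size_t nb00, size_t nb02, size_t nb03) {
    // only one dim-2 slice -> consecutive matrices in the batch are the dim-3 slices
    return ne02 == 1 ? nb03/nb00 : nb02/nb00;
}

static size_t pick_stride_b(int64_t ne12, size_t s12, size_t s13) {
    return ne12 == 1 ? s13 : s12;
}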