@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1852
1852
ggml_cuda_pool_alloc<cuda_t > src0_alloc (ctx.pool ());
1853
1853
ggml_cuda_pool_alloc<cuda_t > src1_alloc (ctx.pool ());
1854
1854
1855
+ bool is_src0_cont_2 = ggml_is_contiguous_2 (src0);
1856
+ bool is_src1_cont_2 = ggml_is_contiguous_2 (src1);
1857
+
1855
1858
// Handle src0
1856
1859
src0_ptr = (const cuda_t *) src0->data ;
1857
1860
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1870
1873
s11 = ne10;
1871
1874
s12 = ne11*s11;
1872
1875
s13 = ne12*s12;
1876
+
1877
+ is_src1_cont_2 = true ;
1873
1878
}
1874
1879
1875
1880
// Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
1918
1923
const int64_t r2 = ne12/ne02;
1919
1924
const int64_t r3 = ne13/ne03;
1920
1925
1921
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
1926
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
1927
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1928
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1929
+ const int64_t smb = ne12 == 1 ? s13 : s12;
1930
+
1922
1931
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
1923
1932
// use cublasGemmStridedBatchedEx
1924
1933
CUBLAS_CHECK (
1925
1934
cublasGemmStridedBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
1926
1935
ne01, ne11, ne10,
1927
- alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1928
- src1_ptr, cu_data_type_b, s11, s12, // strideB
1929
- beta, dst_t , cu_data_type, ne0, ne1*ne0, // strideC
1936
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1937
+ src1_ptr, cu_data_type_b, s11, smb, // strideB
1938
+ beta, dst_t , cu_data_type, ne0, ne1*ne0, // strideC
1930
1939
ne12*ne13,
1931
1940
cu_compute_type,
1932
1941
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
0 commit comments