@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18521852    ggml_cuda_pool_alloc<cuda_t > src0_alloc (ctx.pool ());
18531853    ggml_cuda_pool_alloc<cuda_t > src1_alloc (ctx.pool ());
18541854
1855+     bool  is_src0_cont_2 = ggml_is_contiguous_2 (src0);
1856+     bool  is_src1_cont_2 = ggml_is_contiguous_2 (src1);
1857+ 
18551858    //  Handle src0
18561859    src0_ptr = (const  cuda_t  *) src0->data ;
18571860
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18701873        s11 = ne10;
18711874        s12 = ne11*s11;
18721875        s13 = ne12*s12;
1876+ 
1877+         is_src1_cont_2 = true ;
18731878    }
18741879
18751880    //  Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19181923    const  int64_t  r2 = ne12/ne02;
19191924    const  int64_t  r3 = ne13/ne03;
19201925
1921-     if  (r2 == 1  && r3 == 1  && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
1926+     if  (r2 == 1  && r3 == 1  && is_src0_cont_2 && is_src1_cont_2) {
1927+         //  with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1928+         const  int64_t  sma = ne02 == 1  ? nb03/nb00 : nb02/nb00;
1929+         const  int64_t  smb = ne12 == 1  ? s13       : s12;
1930+ 
19221931        //  there is no broadcast and src0, src1 are contiguous across dims 2, 3
19231932        //  use cublasGemmStridedBatchedEx
19241933        CUBLAS_CHECK (
19251934        cublasGemmStridedBatchedEx (ctx.cublas_handle (), CUBLAS_OP_T, CUBLAS_OP_N,
19261935                ne01, ne11, ne10,
1927-                 alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00,  //  strideA
1928-                        src1_ptr, cu_data_type_b, s11,       s12,        //  strideB
1929-                 beta,     dst_t , cu_data_type,   ne0,       ne1*ne0,    //  strideC
1936+                 alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,      //  strideA
1937+                        src1_ptr, cu_data_type_b, s11,       smb,      //  strideB
1938+                 beta,     dst_t , cu_data_type,   ne0,       ne1*ne0, //  strideC
19301939                ne12*ne13,
19311940                cu_compute_type,
19321941                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
0 commit comments