Skip to content

Commit 32098df

Browse files
committed
Revert "cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1 (ggml-org#15038)"
1 parent 2a6fe6e commit 32098df

File tree

1 file changed

+4
-13
lines changed

1 file changed

+4
-13
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,9 +1853,6 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18531853
ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
18541854
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
18551855

1856-
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
1857-
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
1858-
18591856
// Handle src0
18601857
src0_ptr = (const cuda_t *) src0->data;
18611858

@@ -1874,8 +1871,6 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18741871
s11 = ne10;
18751872
s12 = ne11*s11;
18761873
s13 = ne12*s12;
1877-
1878-
is_src1_cont_2 = true;
18791874
}
18801875

18811876
// Setup destination buffer
@@ -1924,19 +1919,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19241919
const int64_t r2 = ne12/ne02;
19251920
const int64_t r3 = ne13/ne03;
19261921

1927-
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
1928-
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1929-
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1930-
const int64_t smb = ne12 == 1 ? s13 : s12;
1931-
1922+
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
19321923
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
19331924
// use cublasGemmStridedBatchedEx
19341925
CUBLAS_CHECK(
19351926
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
19361927
ne01, ne11, ne10,
1937-
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1938-
src1_ptr, cu_data_type_b, s11, smb, // strideB
1939-
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1928+
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1929+
src1_ptr, cu_data_type_b, s11, s12, // strideB
1930+
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
19401931
ne12*ne13,
19411932
cu_compute_type,
19421933
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

0 commit comments

Comments
 (0)