Skip to content

Commit 15e92fd

Browse files
authored
cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1 (#15038)
* cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1 ggml-ci * cont : fix cont types ggml-ci * cont : adopt variable names and comment from the other branch
1 parent 2bf3fbf commit 15e92fd

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18521852
ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
18531853
ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
18541854

1855+
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
1856+
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
1857+
18551858
// Handle src0
18561859
src0_ptr = (const cuda_t *) src0->data;
18571860

@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
18701873
s11 = ne10;
18711874
s12 = ne11*s11;
18721875
s13 = ne12*s12;
1876+
1877+
is_src1_cont_2 = true;
18731878
}
18741879

18751880
// Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19181923
const int64_t r2 = ne12/ne02;
19191924
const int64_t r3 = ne13/ne03;
19201925

1921-
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
1926+
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
1927+
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1928+
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1929+
const int64_t smb = ne12 == 1 ? s13 : s12;
1930+
19221931
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
19231932
// use cublasGemmStridedBatchedEx
19241933
CUBLAS_CHECK(
19251934
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
19261935
ne01, ne11, ne10,
1927-
alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
1928-
src1_ptr, cu_data_type_b, s11, s12, // strideB
1929-
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
1936+
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1937+
src1_ptr, cu_data_type_b, s11, smb, // strideB
1938+
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
19301939
ne12*ne13,
19311940
cu_compute_type,
19321941
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2688,6 +2688,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
26882688
const size_t type_size_src0 = ggml_type_size(src0->type);
26892689
const size_t type_size_src1 = ggml_type_size(src1->type);
26902690

2691+
bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
2692+
bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
2693+
26912694
// SRC1 strides
26922695
int64_t s11 = nb11 / type_size_src1;
26932696
int64_t s12 = nb12 / type_size_src1;
@@ -2737,6 +2740,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
27372740
s11 = ne10;
27382741
s12 = ne11 * s11;
27392742
s13 = ne12 * s12;
2743+
2744+
is_src1_cont_2 = true;
27402745
}
27412746

27422747
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2852,12 +2857,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
28522857
else
28532858
#endif
28542859
{
2855-
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2860+
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
2861+
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
2862+
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
2863+
const int64_t smb = ne12 == 1 ? s13 : s12;
2864+
28562865
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
28572866
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
28582867
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2859-
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
2860-
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
2868+
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
2869+
src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
28612870
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
28622871
} else {
28632872
const int ne23 = ne12 * ne13;

0 commit comments

Comments
 (0)