Skip to content

Commit 275a591

Browse files
committed
cont : adopt variable names and comment from the other branch
1 parent 18388c7 commit 275a591

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1924,16 +1924,17 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
19241924
const int64_t r3 = ne13/ne03;
19251925

19261926
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
1927-
const size_t sa = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1928-
const size_t sb = ne12 == 1 ? s13 : s12;
1927+
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
1928+
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
1929+
const int64_t smb = ne12 == 1 ? s13 : s12;
19291930

19301931
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
19311932
// use cublasGemmStridedBatchedEx
19321933
CUBLAS_CHECK(
19331934
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
19341935
ne01, ne11, ne10,
1935-
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sa, // strideA
1936-
src1_ptr, cu_data_type_b, s11, sb, // strideB
1936+
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
1937+
src1_ptr, cu_data_type_b, s11, smb, // strideB
19371938
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
19381939
ne12*ne13,
19391940
cu_compute_type,

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2858,14 +2858,15 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
28582858
#endif
28592859
{
28602860
if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
2861-
const size_t sa = ne02 == 1 ? nb03/nb00 : nb02/nb00;
2862-
const size_t sb = ne12 == 1 ? s13 : s12;
2861+
// with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
2862+
const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
2863+
const int64_t smb = ne12 == 1 ? s13 : s12;
28632864

28642865
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
28652866
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
28662867
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2867-
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sa,
2868-
src1_f16, dpct::library_data_t::real_half, s11, sb, beta, dst_ddf,
2868+
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
2869+
src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
28692870
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
28702871
} else {
28712872
const int ne23 = ne12 * ne13;

0 commit comments

Comments
 (0)