Commit 747c12e

sycl: Batched mulmat rework for oneDNN dispatch
1 parent 704bb7a commit 747c12e

File tree

2 files changed: 142 additions & 79 deletions

ggml/src/ggml-sycl/gemm.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 20 additions & 32 deletions
@@ -32,39 +32,28 @@ class DnnlGemmWrapper {
         else static_assert(0);
     }
 
-    // matrix A has m rows, k columns
-    // matrix B has k rows, n columns
-    // nra - number of elements to skip when moving into next row in A
-    // nrb - number of elements to skip when moving into next row in B
-    // nca - number of elements to skip when moving into next column in A
-    // ncb - number of elements to skip when moving into next column in B
-    // stride_a - number of elements to skip when moving to next A matrix
-    // stride_b - number of elements to skip when moving to next B matrix
-    // batches_a - number of A matrices
-    // batches_b - number of B matrices
     static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
-        const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
-        const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+        const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+        const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
         void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
 
-        auto stream = ctx.stream_dnnl(q);
+        auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-
-        // { # strides, # rows, # columns }
-        dnnl::memory::dims a_dims = { batches_a, m, k };
-        dnnl::memory::dims b_dims = { batches_b, k, n };
-        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
-
-        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
-        dnnl::memory::dims a_strides = { stride_a, nra, nca };
-        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
-
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
-
+
+        dnnl::memory::dims a_dims = {batches_a, m, k };
+        dnnl::memory::dims a_strides = {stra2, stra1, stra0};
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+
+        dnnl::memory::dims b_dims = {batches_b, k, n };
+        dnnl::memory::dims b_strides = {strb2, strb0, strb1};
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
+        dnnl::memory::dims c_strides = {m*n, 1, m };
+        const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides);
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
 #ifdef GGML_SYCL_F16
         primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
 #endif
@@ -75,25 +64,24 @@ class DnnlGemmWrapper {
         auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
 
         auto scratchpad_md = matmul_pd.scratchpad_desc();
-        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
         auto matmul_prim = dnnl::matmul(matmul_pd);
 
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
         matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
 
-    // matrices A and B are column major, both having k rows
-    // matrix A has m column, matrix B has n columns
-    // output: column major matrix C = A transposed * B
     static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
 
-        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
     }
 };
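
Note on the new interface: gemm() no longer assumes packed operands. Each operand is described to oneDNN by three element strides (innermost dimension, next dimension, matrix-in-batch), and the destination uses the fixed {m*n, 1, m} element-stride layout. A minimal standalone sketch of the descriptor construction this relies on; the function name and parameters are illustrative only, not part of the commit:

    #include <oneapi/dnnl/dnnl.hpp>

    // Sketch: describe a strided fp16 A operand of shape batches x m x k to oneDNN,
    // the way the reworked gemm() does, instead of assuming a packed layout tag.
    dnnl::memory::desc make_a_desc(dnnl::memory::dim batches, dnnl::memory::dim m, dnnl::memory::dim k,
                                   dnnl::memory::dim str0,    // step to the next column, in elements
                                   dnnl::memory::dim str1,    // step to the next row
                                   dnnl::memory::dim str2) {  // step to the next matrix in the batch
        dnnl::memory::dims a_dims    = { batches, m, k };
        dnnl::memory::dims a_strides = { str2, str1, str0 };
        return dnnl::memory::desc(a_dims, dnnl::memory::data_type::f16, a_strides);
    }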

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 122 additions & 47 deletions
@@ -1546,8 +1546,9 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
+
 
     const sycl::half *x = (const sycl::half *)vx;
 
@@ -1557,7 +1558,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                           item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
 
@@ -1576,7 +1576,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
@@ -1823,7 +1823,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1835,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
+                                       row_stride_x, channel_stride_x, channel_stride_y,
                                        nchannels_y / nchannels_x, item_ct1);
             });
     }
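
The extra channel_stride_y parameter threaded through the hunks above lets the kernel step between src1 channels using the tensor's actual stride (nb[1] / sizeof(float)) instead of assuming ncols_x contiguous floats. A hypothetical, standalone illustration of the index arithmetic (all values made up, not taken from the commit):

    #include <cstdio>

    int main() {
        const int ncols_x          = 128;  // ne00; also the y stride the old code assumed
        const int channel_stride_y = 160;  // nb11 / sizeof(float) for a padded, non-contiguous view
        const int channel = 2, row_y = 5;
        std::printf("old iy = %d\n", channel * ncols_x          + row_y);  // 261
        std::printf("new iy = %d\n", channel * channel_stride_y + row_y);  // 325
        return 0;
    }

When src1 is contiguous the two indices agree; for a strided view only the second one lands on the element the layout actually stores.
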
@@ -2124,8 +2124,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
-                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2171,8 +2171,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2776,6 +2776,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     const int64_t nb02 = src0->nb[2];
 
     const int64_t ne12 = src1->ne[2];
+    const int64_t nb11 = src1->nb[1];
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
@@ -2786,8 +2787,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
 
     const int64_t row_stride_x = nb01 / sizeof(sycl::half);
     const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    const int64_t channel_stride_y = nb11 / sizeof(float);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -2841,8 +2843,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     float * dst_ddf = static_cast<float *>(dst->data);
 
     const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
-    GGML_ASSERT(nb10 == type_size_src1);
 
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
@@ -2854,11 +2856,33 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     if (src1->type != GGML_TYPE_F16) {
         scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
                                              " : converting src1 to fp16");
-        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
-        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+
+        // iterate tensor dims and find the slowest moving dim and stride
+        int64_t last_dim=0;
+        int64_t last_str=0;
+        int64_t largest_str=0;
+        for(int i = 0; i< 4; i++){
+            // last stride is always the largest
+            if(src1->nb[i] == largest_str){
+                if(src1->ne[last_dim] == 1){
+                    last_str = i;
+                    last_dim = i;
+                }
+            }
+            if(src1->nb[i] > largest_str){
+                largest_str = src1->nb[i];
+                last_str = i;
+                last_dim = i;
+            }
+
+        }
+        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
+        src1_f16_alloc.alloc(ne_src1);
+
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
 
         src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
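
The loop added above sizes the fp16 scratch buffer from the slowest-moving dimension of src1 (its largest byte stride times that dimension's extent), so the buffer spans the whole view, padding included, before the plain contiguous fp16 conversion runs over it. A standalone sketch of the same search on a hypothetical 4-D view; shapes, strides, and the simplification below are illustrative only:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne[4] = { 64, 4, 32, 1 };         // extents of a permuted fp16 view
        const int64_t nb[4] = { 2, 4096, 128, 16384 };  // byte strides per dimension
        const int64_t type_size = 2;                    // bytes per fp16 element

        int64_t last_dim = 0, largest_str = 0;
        for (int i = 0; i < 4; i++) {
            if (nb[i] > largest_str) { largest_str = nb[i]; last_dim = i; }
        }
        const int64_t ne_src1 = nb[last_dim] * ne[last_dim] / type_size;
        std::printf("slowest dim = %lld, elements to allocate = %lld\n",
                    (long long) last_dim, (long long) ne_src1);  // dim 3, 8192 elements
        return 0;
    }

The committed loop additionally handles the tie case where a trailing dimension of extent 1 repeats the largest stride; the sketch keeps only the core search.
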
@@ -2892,38 +2916,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
 #if GGML_SYCL_DNNL
     if (!g_ggml_sycl_disable_dnn) {
-        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
-                        (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
-
-            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
-                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
-                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
-                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
-        };
-
-        if (r2 == 1 && r3 == 1) {
-            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
-            }
-            else {
-                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
-                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
-                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
-                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+        int64_t str_a0 = nb00 / type_size_src0;
+        int64_t str_a1 = nb01 / type_size_src0;
+        int64_t str_a2 = nb02 / type_size_src0;
+
+        int64_t str_b0 = nb10 / type_size_src1;
+        int64_t str_b1 = nb11 / type_size_src1;
+        int64_t str_b2 = nb12 / type_size_src1;
+
+        auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+                                           const sycl::half *src1, float *dst,
+                                           int64_t a0, int64_t a1, int64_t batcha,
+                                           int64_t b0, int64_t b1, int64_t batchb,
+                                           int64_t sa0, int64_t sa1, int64_t sa2,
+                                           int64_t sb0, int64_t sb1, int64_t sb2,
+                                           int64_t sd2) {
+            bool supported_broadcast = batchb == batcha ? true
+                                       : batchb == 1 || batcha == 1 ? true
+                                                                    : false;
+            if (supported_broadcast) {
+                DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+                    DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+            } else {
+                // iterate over batches from smaller set of matrices (matrix 0)
+                int64_t batches0 = batcha;
+                int64_t batches1 = batchb;
+
+                if (batches0 > batches1) {
+                    int64_t num_mul_mats = batches1;
+                    int64_t sub_batch = batches0 / num_mul_mats;
+                    // src0 is batched and bigger, shift and multiply with src1
+                    for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i0);
+                        float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, sub_batch, 1);
+                    }
+                } else {
+                    int64_t num_mul_mats = batches0;
+                    int64_t sub_batch = batches1 / num_mul_mats;
+                    // src1 is batched and bigger, shift and multiply with src0
+                    for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i1);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+                        float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, 1, sub_batch);
+                    }
+                }
             }
-            }
-        } else {
-            // iterate over batches from smaller set of matrices (matrix 0)
-            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
-                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
-                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
-                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
-                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+        };
+
+        bool cont_batches_a = nb02 * ne02 == nb03;
+        bool cont_batches_b = nb12 * ne12 == nb13;
+        if (cont_batches_a && cont_batches_b) {
+            int64_t batches0 = ne02 * ne03;
+            int64_t batches1 = ne12 * ne13;
+            launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+                str_b2, nb2 / sizeof(float));
+        } else {
+            for (int64_t b_a = 0; b_a < ne03; b_a++) {
+                const sycl::half *src0_f16_shifted
+                    = src0_f16 + (nb03 * b_a / type_size_src0);
+                const sycl::half *src1_f16_shifted
+                    = src1_f16 + (nb13 * b_a / type_size_src1);
+                float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+                int64_t batches0 = ne02;
+                int64_t batches1 = ne12;
+                launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+                    ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+                    str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
            }
        }
-        }
+
    }
    else
 #endif
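
The launch_gemm_for_batches helper introduced above only falls back to manual batching when oneDNN's batched matmul cannot express the shape: equal batch counts, or a batch count of 1 on either side, go through in a single call, otherwise the larger batch is processed in groups of sub_batch matrices per call. A small standalone sketch of that dispatch rule (batch counts are illustrative, not from the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t batcha = 4;   // e.g. ne02 * ne03 matrices in src0
        const int64_t batchb = 8;   // e.g. ne12 * ne13 matrices in src1

        if (batcha == batchb || batcha == 1 || batchb == 1) {
            std::printf("single oneDNN matmul call (batched or broadcast)\n");
        } else {
            const int64_t num_mul_mats = batcha < batchb ? batcha : batchb;
            const int64_t sub_batch    = (batcha > batchb ? batcha : batchb) / num_mul_mats;
            std::printf("%lld calls, each multiplying a sub-batch of %lld matrices\n",
                        (long long) num_mul_mats, (long long) sub_batch);  // 4 calls of 2
        }
        return 0;
    }
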
@@ -3263,10 +3338,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
