
Commit 8cbe2c9

sycl: reordered Q4_K MMVQ
1 parent 52b1622 commit 8cbe2c9

7 files changed: 278 additions & 87 deletions

ggml/src/ggml-sycl/convert.cpp: 22 additions & 1 deletion

@@ -183,6 +183,23 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    const int64_t nb = k / QK_K;
+    {
+        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+        stream->submit([&](sycl::handler & cgh) {
+            sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+                [=](sycl::nd_item<3> item_ct1) {
+                    dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
+                });
+        });
+    }
+}
+
 template <typename dst_t>
 static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {

@@ -493,7 +510,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
         case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
+            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_K_sycl_reorder;
+            } else {
+                return dequantize_row_q4_K_sycl;
+            }
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
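
For orientation, here is a standalone sketch of the thread mapping used by these 32-work-item Q4_K kernels: each of the nb work-groups covers one QK_K = 256 super-block, and each work-item dequantizes 4 low nibbles plus 4 high nibbles. A minimal model assuming QK_K == 256; the constant and printout are illustrative only, not part of the commit:

#include <cstdio>

// Hypothetical stand-in for ggml's QK_K (values per super-block).
constexpr int QK = 256;

int main() {
    const int i = 0;  // work-group id == super-block index
    for (int tid = 0; tid < 32; ++tid) {
        const int il = tid / 8;                    // which 64-value quarter of the block
        const int ir = tid % 8;                    // which 4-byte chunk inside that quarter
        const int qs = 32 * il + 4 * ir;           // first packed byte this work-item reads
        const int y0 = i * QK + 64 * il + 4 * ir;  // low nibbles land here
        const int y1 = y0 + 32;                    // high nibbles land 32 floats later
        std::printf("tid %2d: qs[%3d..%3d] -> y[%3d..%3d], y[%3d..%3d]\n",
                    tid, qs, qs + 3, y0, y0 + 3, y1, y1 + 3);
    }
    return 0;
}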

ggml/src/ggml-sycl/dequantize.hpp: 60 additions & 21 deletions

@@ -357,6 +357,31 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
 }
 #endif
 
+template <typename dst_t>
+inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
+                                   const float dmin, uint8_t * __restrict__ scales_local,
+                                   const sycl::nd_item<3> & item_ct1, int il, int ir) {
+    const int is = 2 * il;
+    const int n = 4;
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, scales_local, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+
+    get_scale_min_k4(is + 1, scales_local, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
+        y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
+    }
+}
+
 template<typename dst_t>
 static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {

@@ -365,36 +390,21 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
     const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
-    // assume 32 threads
     const int64_t tid = item_ct1.get_local_id(2);
-    const int64_t il = tid/8;
-    const int64_t ir = tid%8;
-    const int64_t is = 2*il;
-    const int64_t n = 4;
+    const int64_t il = tid / 8;
+    const int64_t ir = tid % 8;
 
-    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
 
     const sycl::half2 dm = x[i].dm;
     const float dall = dm[0];
     const float dmin = dm[1];
 
-    if (tid < 12)
+    if (tid < 12) {
         scales_local[tid] = x[i].scales[tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, scales_local, sc, m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, scales_local, sc, m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
-
-    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
-    for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
-        y[l +32] = d2 * (q_vec[l] >> 4) - m2;
     }
+
+    dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, item_ct1, il, ir);
 #else
     const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;

@@ -406,6 +416,35 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
 #endif
 }
 
+template <typename dst_t>
+static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
+                                          const sycl::nd_item<3> & item_ct1, int64_t nb) {
+    const int64_t i = item_ct1.get_group(2);
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid / 8;
+    const int64_t ir = tid % 8;
+
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
+
+    const uint8_t * base = static_cast<const uint8_t *>(vx);
+    const size_t qs_offset = i * (QK_K / 2);
+    const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
+    const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
+
+    const uint8_t * qs_ptr = base + qs_offset;
+    const uint8_t * scales_ptr = base + scales_offset;
+    const ggml_half2 * dm_ptr = reinterpret_cast<const ggml_half2 *>(base + dm_offset);
+
+    const float dall = dm_ptr->x();
+    const float dmin = dm_ptr->y();
+
+    if (tid < 12) {
+        scales_local[tid] = scales_ptr[tid];
+    }
+
+    dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, item_ct1, il, ir);
+}
+
 template<typename dst_t>
 static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
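
As a reading aid, here is a scalar model of what dequantize_q4_K_common computes per work-item once get_scale_min_k4 has produced the two (scale, min) pairs. The helper below is hypothetical, not part of the commit; it assumes each packed byte carries a low nibble for the first 32-value half of a quarter and a high nibble for the second:

#include <cstdint>

// Dequantize one 4-byte chunk: 4 values from low nibbles (pair d1/m1) and
// 4 values from high nibbles (pair d2/m2), written 32 floats apart.
void dequant_chunk4(float * y, const uint8_t * qs, float d1, float m1, float d2, float m2) {
    for (int l = 0; l < 4; ++l) {
        y[l + 0]  = d1 * (qs[l] & 0xF) - m1;  // low nibble, first half
        y[l + 32] = d2 * (qs[l] >> 4) - m2;   // high nibble, second half
    }
}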

ggml/src/ggml-sycl/dmmv.cpp: 7 additions & 1 deletion

@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
             dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                // reorder is currently not supported for dmmv
+                GGML_ABORT("Unimplemented dequantize case for q4_k reorder");
+            } else {
+                dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q5_K:
             dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);

ggml/src/ggml-sycl/ggml-sycl.cpp: 69 additions & 19 deletions

@@ -337,7 +337,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0) {
+    if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) {
        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
        tensor->extra = extra;
        ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.

@@ -2890,6 +2890,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;

@@ -2915,6 +2916,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
                               ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);

@@ -2938,14 +2949,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
+    // TODO: make these into functions, add mmvq check for reorder
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

@@ -3622,15 +3630,14 @@ static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
 
     GGML_UNUSED(backend);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
+static void reorder_qw_q4_0(char * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
     auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)

@@ -3657,22 +3664,65 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
     sycl::free(tmp_buf, *stream);
 }
 
-static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+static void reorder_qw_q4_k(char * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto tmp_buf = sycl::malloc_device<char>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr = (uint8_t *) data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    });
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream, ggml_type type) {
     char*data_device = (char*)src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_SYCL_DEBUG("reorder_qw() called with unsupported type");
+            break;
    }
 }
 
 static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
     ggml_tensor *src0 = dst->src[0];
     ggml_tensor *src1 = dst->src[1];
+    const bool is_q4_k_mmvq =
+        dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_K && can_use_mul_mat_vec_q(src0, src1, dst);
 
-    if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2]==1 && src1->ne[3]==1) {
-        reorder_qw(src0, stream);
+    if (dst->op == GGML_OP_MUL_MAT && (src0->type == GGML_TYPE_Q4_0 || is_q4_k_mmvq) && src1->ne[2] == 1 &&
+        src1->ne[3] == 1) {
+        reorder_qw(src0, stream, src0->type);
         ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
         GGML_ASSERT(extra);
         extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
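
The net effect of reorder_qw_q4_k is a structure-of-arrays layout: all packed quants first, then all packed scales, then all (d, dmin) pairs, so the MMVQ kernel can issue coalesced loads per field. A host-side sketch of the resulting byte offsets, assuming ggml's QK_K = 256, K_SCALE_SIZE = 12, and a 4-byte half2; the names are illustrative, not from the commit:

#include <cstddef>

// Byte offsets of block ib's three fields inside the reordered buffer.
struct Q4KReorderedOffsets {
    size_t qs;      // packed 4-bit quants of block ib
    size_t scales;  // packed 6-bit scales/mins of block ib
    size_t dm;      // (d, dmin) half2 of block ib
};

Q4KReorderedOffsets q4k_offsets(size_t nblocks, size_t ib) {
    constexpr size_t qk = 256, scale_size = 12, half2_size = 4;  // assumed constants
    const size_t all_qs     = nblocks * qk / 2;       // total quants segment size
    const size_t all_scales = nblocks * scale_size;   // total scales segment size
    return { ib * qk / 2,                             // inside the quants segment
             all_qs + ib * scale_size,                // after all quants
             all_qs + all_scales + ib * half2_size }; // after quants and scales
}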

ggml/src/ggml-sycl/mmvq.cpp: 29 additions & 2 deletions

@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
     const int blocks_per_row = ncols / block_traits::qk;
     constexpr int blocks_per_subgroup = safe_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
     constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
+    const int nblocks = nrows * (ncols / block_traits::qk);
 
     assert(blocks_per_subgroup > 0);
     assert(block_elements_per_subgroup > 0);

@@ -44,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
         }
     }

@@ -738,6 +739,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+                                               const int nrows, dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+
+    const int block_num_y = safe_div(nrows, GGML_SYCL_MMV_Y);
+    constexpr size_t num_subgroups = 16;
+    GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                         [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                             mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
+                                                                                           nrows, nd_item);
+                         });
+    });
+}
+
 static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,

@@ -1032,7 +1054,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
             mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            } else {
+                mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
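
A worked check of the launch geometry above, under the assumption that WARP_SIZE is 32 and GGML_SYCL_MMV_Y is 1 (common defaults; verify against your build). This standalone sketch is not part of the commit:

#include <cassert>

int main() {
    constexpr int warp_size = 32, mmv_y = 1, num_subgroups = 16;  // assumed build constants
    const int nrows = 4096;

    // safe_div(nrows, GGML_SYCL_MMV_Y): one subgroup handles one row here.
    const int block_num_y = (nrows + mmv_y - 1) / mmv_y;
    assert(block_num_y % num_subgroups == 0);  // mirrors the kernel's GGML_ASSERT

    const long global_z = static_cast<long>(block_num_y) * warp_size;  // 4096 * 32 = 131072 work-items
    const int  local_z  = num_subgroups * warp_size;                   // 16 * 32 = 512 per work-group
    assert(global_z % local_z == 0);  // nd_range needs the local range to divide the global range

    return 0;  // 4096 rows -> 256 work-groups of 16 subgroups each
}

Under these defaults the divisibility assert effectively requires nrows to be a multiple of 16.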

ggml/src/ggml-sycl/quants.hpp: 22 additions & 0 deletions

@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q4_K> {
+    struct traits {
+        static constexpr uint32_t qk = QK_K;
+        static constexpr uint32_t qi = QI4_K;
+        static constexpr uint32_t qr = QR4_K;
+        static constexpr uint32_t vdr_mmvq = 2;
+    };
+
+    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks = (nrows * (ncols / traits::qk));
+        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+
+    constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
+
+    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
+};
+
 }  // namespace ggml_sycl_reordered
 
 #endif  // GGML_SYCL_QUANTS_HPP
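
A worked example of the new offset helpers, assuming ggml's usual constants (QK_K = 256, QR4_K = 2, K_SCALE_SIZE = 12) and a 4-byte ggml_half2; the stand-in names below are illustrative, not from the commit:

#include <cstddef>

// Stand-ins for QK_K, QR4_K, K_SCALE_SIZE and sizeof(ggml_half2).
constexpr int qk = 256, qr = 2, scale_size = 12, half2_size = 4;

constexpr int get_block_offset(int block_index) { return block_index * (qk / qr); }

constexpr int get_d_offset(int nrows, int ncols, int block_index) {
    const int nblocks = nrows * (ncols / qk);
    return nblocks * qk / 2 + nblocks * scale_size + block_index * half2_size;
}

// A 2-row x 512-column tensor has nblocks = 2 * (512 / 256) = 4.
static_assert(get_block_offset(3) == 384, "quants of block 3 start at byte 384");
static_assert(get_d_offset(2, 512, 3) == 572, "dm of block 3: 4*128 + 4*12 + 3*4");

int main() { return 0; }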
