@@ -337,7 +337,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0) {
+    if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); // used to release it when destroy ctx.
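
For orientation, a minimal sketch of the per-tensor metadata allocated in this hunk. Only the reorder flag matters for this patch (opt_for_reorder() sets it further down); the real struct carries more state, so the shape below is an assumption:

// Hypothetical, simplified view of the extra attached to Q4_0/Q4_K tensors.
struct optimized_feature_t {
    bool reorder = false;  // set once the weight blocks have been re-laid-out
};
struct ggml_tensor_extra_gpu_sketch {
    optimized_feature_t optimized_feature;  // consulted by the dequant/mmvq kernels
};
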
@@ -2890,6 +2890,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2915,6 +2916,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
                               ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
@@ -2938,14 +2949,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
+    // TODO: make these into functions, add mmvq check for reorder
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
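
To make the new predicates concrete, a worked check for a typical single-token decode; the shapes are hypothetical, and GGML_SYCL_DMMV_X / MMVQ_MAX_BATCH_SIZE are backend constants whose values this diff does not show:

// src0 = 4096 x 4096 Q4_K weights, src1 = 4096 x 1 F32 activations, dst = F32
// can_use_dequantize_mul_mat_vec: supports_dmmv(Q4_K) && 4096 % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1
// can_use_mul_mat_vec_q:          ggml_is_quantized(Q4_K) && src1->ne[1] == 1 <= MMVQ_MAX_BATCH_SIZE
// For a large prompt batch, src1->ne[1] grows: the dmmv path is ruled out first, then mmvq once
// ne[1] exceeds MMVQ_MAX_BATCH_SIZE, leaving use_mul_mat_q and the generic paths.
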
@@ -3622,15 +3630,14 @@ static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
 
     GGML_UNUSED(backend);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << " Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << " Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
+static void reorder_qw_q4_0(char * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
     auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
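
Both reorder helpers, the Q4_0 one here and the Q4_K one in the next hunk, follow the stage-and-rewrite pattern that starts above: snapshot the weights into a temporary allocation, rewrite data_device in place from the snapshot, then free the staging buffer. In outline:

// 1. stage:   tmp_buf = sycl::malloc_*<char>(size, q);  q.memcpy(tmp_buf, data_device, size).wait();
// 2. rewrite: a kernel reads block ib from tmp_buf and writes the planar layout back into data_device
// 3. cleanup: sycl::free(tmp_buf, q);
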
@@ -3657,22 +3664,65 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
     sycl::free(tmp_buf, *stream);
 }
 
-static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+static void reorder_qw_q4_k(char * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto tmp_buf = sycl::malloc_device<char>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr = (uint8_t *) data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    });
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream, ggml_type type) {
     char *data_device = (char *)src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_SYCL_DEBUG("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
     ggml_tensor *src0 = dst->src[0];
     ggml_tensor *src1 = dst->src[1];
+    const bool is_q4_k_mmvq =
+        dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_K && can_use_mul_mat_vec_q(src0, src1, dst);
 
-    if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2] == 1 && src1->ne[3] == 1) {
-        reorder_qw(src0, stream);
+    if (dst->op == GGML_OP_MUL_MAT && (src0->type == GGML_TYPE_Q4_0 || is_q4_k_mmvq) && src1->ne[2] == 1 &&
+        src1->ne[3] == 1) {
+        reorder_qw(src0, stream, src0->type);
         ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
         GGML_ASSERT(extra);
         extra->optimized_feature.reorder = true; // used to decode/dequan in next steps.
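
To show what consumers of the reordered buffer see, a hedged sketch of block addressing in the new layout; the helper name is hypothetical, the offsets mirror reorder_qw_q4_k() above, and QK_K / K_SCALE_SIZE come from ggml's Q4_K definitions:

// Planar layout produced by reorder_qw_q4_k():
//   [ qs: nblocks * QK_K/2 bytes | scales: nblocks * K_SCALE_SIZE bytes | dm: nblocks * sycl::half2 ]
static inline void get_reordered_q4_k_block(const uint8_t * base, int nblocks, int ib,
                                            const uint8_t ** qs, const uint8_t ** scales,
                                            const sycl::half2 ** dm) {
    const uint8_t * scales_base = base + (size_t) nblocks * (QK_K / 2);
    const uint8_t * dm_base     = scales_base + (size_t) nblocks * K_SCALE_SIZE;
    *qs     = base + (size_t) ib * (QK_K / 2);           // 4-bit quants of block ib
    *scales = scales_base + (size_t) ib * K_SCALE_SIZE;  // packed 6-bit scales/mins
    *dm     = (const sycl::half2 *) dm_base + ib;        // super-block d / dmin pair
}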