@@ -352,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); // used to release it when destroy ctx.
@@ -2900,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return true;
+        case GGML_TYPE_Q4_K:
+            return !g_ggml_sycl_prioritize_dmmv;
         default:
             return false;
     }
@@ -2917,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2942,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2965,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
                 *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
             }
             *(d_ptr + ib) = x[ib].d;
-        });
+        }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr     = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char * data_device = (char *) src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
@@ -3019,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
     extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
@@ -3043,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
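
Note on the reordered Q4_K layout: the reorder_qw_q4_k() kernel above splits each block_q4_K into three contiguous sections of the same buffer, all quantized nibbles (QK_K/2 bytes per block) first, then all packed scales (K_SCALE_SIZE bytes per block), then one (d, dmin) half2 per block. The standalone C++ sketch below (not part of the patch; helper names are illustrative) just computes where a given block's data lands in that layout, assuming ggml's default QK_K = 256 and K_SCALE_SIZE = 12.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Constants mirroring ggml's default Q4_K block geometry.
constexpr size_t QK_K         = 256;  // values per super-block
constexpr size_t K_SCALE_SIZE = 12;   // packed 6-bit scales/mins per super-block

struct q4_k_reordered_offsets {
    size_t qs;      // byte offset of this block's quantized nibbles
    size_t scales;  // byte offset of this block's packed scales/mins
    size_t dm;      // byte offset of this block's (d, dmin) half pair
};

// For block `ib` out of `nblocks`, compute its offsets in the reordered buffer:
// [ all qs | all scales | all dm ], matching the pointer math in reorder_qw_q4_k().
static q4_k_reordered_offsets layout_q4_k(size_t ib, size_t nblocks) {
    const size_t scales_section = nblocks * (QK_K / 2);
    const size_t dm_section     = scales_section + nblocks * K_SCALE_SIZE;
    return {
        ib * (QK_K / 2),
        scales_section + ib * K_SCALE_SIZE,
        dm_section + ib * 2 * sizeof(uint16_t),  // sycl::half2 modelled as two 16-bit halves
    };
}

int main() {
    const q4_k_reordered_offsets off = layout_q4_k(/*ib=*/3, /*nblocks=*/8);
    std::printf("qs @ %zu, scales @ %zu, dm @ %zu\n", off.qs, off.scales, off.dm);
    return 0;
}

Keeping each field type contiguous across blocks is what lets the reordered mul_mat kernels read nibbles, scales, and dm values with unit-stride accesses instead of striding over whole block_q4_K structs.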