@@ -341,7 +341,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0) {
+    if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra = extra;
         ctx->tensor_extras.push_back(extra); // used to release it when destroy ctx.
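
(The hunk above extends the init-tensor gating to Q4_K. As a minimal sketch, the same condition could be factored into a predicate; the helper name below is hypothetical and not part of this patch.)

// Hypothetical helper, not in the patch: true for the quantized types that
// receive a ggml_tensor_extra_gpu so their data can later be reordered.
static bool tensor_gets_reorder_extra(const ggml_tensor * tensor) {
    return tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K;
}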
@@ -2840,6 +2840,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_dequantize(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2858,6 +2859,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
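
(Both reorder predicates above now accept Q4_0 and Q4_K. A hedged usage sketch, with the call shape assumed rather than copied from the dispatch code further down:)

// Reorder src0's quantized blocks once, up front, only when the mmvq path
// can consume the reordered layout; reorder_qw() aborts on other types.
if (ggml_sycl_supports_reorder_mmvq(src0->type)) {
    reorder_qw(src0, stream);
}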
@@ -2883,16 +2885,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char * data_device, const int ncols, const int nrows,
-                       size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t *) data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half *) (qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2906,18 +2908,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
             *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
         }
         *(d_ptr + ib) = x[ib].d;
-    });
+    }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x = (const block_q4_K *) tmp_buf;
+        const int ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char * data_device = (char *) src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
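
(For reference: reorder_qw_q4_k() converts the array-of-structs block layout into three contiguous regions, all qs first, then all scales, then all dm, so the kernels get unit-stride loads. Below is a sequential host-side sketch of the same transform, assuming the usual block_q4_K fields; it is an illustration only, not part of the patch. Note also that for a contiguous Q4_K tensor, size == ncols * nrows / QK_K * sizeof(block_q4_K), so nblocks equals the number of Q4_K super-blocks in the tensor.)

// Illustration only: host-side equivalent of reorder_qw_q4_k().
// Destination layout: [qs x nblocks][scales x nblocks][dm x nblocks].
static void reorder_q4_k_host(const block_q4_K * src, uint8_t * dst, int nblocks) {
    uint8_t * qs     = dst;
    uint8_t * scales = qs + (QK_K / 2) * nblocks;
    uint8_t * dm     = scales + K_SCALE_SIZE * nblocks;
    for (int ib = 0; ib < nblocks; ++ib) {
        memcpy(qs + ib * (QK_K / 2), src[ib].qs, QK_K / 2);               // 4-bit quants
        memcpy(scales + ib * K_SCALE_SIZE, src[ib].scales, K_SCALE_SIZE); // packed scales/mins
        memcpy(dm + ib * sizeof(src[ib].dm), &src[ib].dm, sizeof(src[ib].dm)); // half2 scale/min pair
    }
}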
@@ -2943,8 +2986,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
     }
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
@@ -2966,14 +3019,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }
 
+    // TODO: make these into functions, add mmvq check for reorder
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
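
(A worked example of how the two extracted predicates differ, with shapes invented for illustration and not taken from the patch:)

// Hypothetical dispatch walkthrough:
//   src0: 4096 x 4096, GGML_TYPE_Q4_K      src1: 4096 x 4, GGML_TYPE_F32
//   can_use_dequantize_mul_mat_vec -> false (src1->ne[1] == 4, path needs 1)
//   can_use_mul_mat_vec_q          -> true  (4 <= MMVQ_MAX_BATCH_SIZE)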