@@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
3131
3232    float  partial_sum = 0 .0f ;
3333    for  (int  i = sg.get_local_linear_id () / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
34-         const  int  ibx       = row * blocks_per_row + i;  //  x block index
35-         //  TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
36-         const  int  bx_offset = block_type::get_block_offset (ibx);
37-         const  int  d_offset  = block_type::get_d_offset (nrows, ncols, ibx);
34+         const  int  ibx = row * blocks_per_row + i;  //  x block index
3835
36+         const  auto          bx_offset      = block_type::get_block_offset (ibx, nblocks);
37+         const  auto          d_offset       = block_type::get_d_offset (nrows, ncols, ibx);
3938        //  Y block index that aligns with ibx
4039        const  int  iby = i * block_type::block_to_q8_1_ratio ();
4140        const  int8_t * q8_1_quant_ptr = (const  int8_t *)vy + iby * QK8_1;
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
4645            //  x block quant index when casting the quants to int
4746            const  int  iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id () % block_elements_per_subgroup);
4847
49-             partial_sum += reorder_vec_dot_q_sycl ()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks );
48+             partial_sum += reorder_vec_dot_q_sycl ()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
5049        }
5150    }
5251
@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
785784    }
786785}
787786
// Host-side launcher: submits one SYCL kernel on `stream` that runs
// mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>> over a 3-D nd_range.
//
// Parameters (forwarded unchanged to the kernel):
//   vx, vy  - opaque input buffers. NOTE(review): presumably the reordered Q6_K weights and the
//             q8_1-quantized activations respectively - confirm against the caller.
//   dst     - float output buffer.
//   ncols   - must be a multiple of QK_K (asserted below).
//   nrows   - ceil_div(nrows, GGML_SYCL_MMV_Y) must be a multiple of 16 (asserted below).
//   stream  - queue the kernel is submitted to; this function does not wait on the submission.
787+ static  void  reorder_mul_mat_vec_q6_k_q8_1_sycl (const  void  * vx, const  void  * vy, float  * dst, const  int  ncols,
788+                                                const  int  nrows, dpct::queue_ptr stream) {
789+     GGML_ASSERT (ncols % QK_K == 0 );
// Number of row groups along the last range dimension, rounded up.
790+     const  int         block_num_y   = ceil_div (nrows, GGML_SYCL_MMV_Y);
// Subgroups per work-group; the work-group's last dimension is num_subgroups * WARP_SIZE.
791+     constexpr  size_t  num_subgroups = 16 ;
// Keeps the global size (block_num_y * WARP_SIZE) an exact multiple of the work-group size
// (num_subgroups * WARP_SIZE), so the nd_range built below divides evenly.
792+     GGML_ASSERT (block_num_y % num_subgroups == 0 );
793+ 
794+     const  sycl::range<3 > global_size (1 , GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
795+     const  sycl::range<3 > workgroup_size (1 , GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
796+ 
// The command-group lambda captures by reference; the kernel lambda captures by value, so the
// pointers and sizes are copied into the kernel object.
797+     stream->submit ([&](sycl::handler & cgh) {
798+         cgh.parallel_for (sycl::nd_range<3 >(global_size, workgroup_size),
// The kernel is compiled with a required sub-group size of WARP_SIZE.
799+                          [=](sycl::nd_item<3 > nd_item) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
800+                              mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
801+                                                                                            nd_item);
802+                          });
803+     });
804+ }
788805static  void  mul_mat_vec_q6_K_q8_1_sycl (const  void  *vx, const  void  *vy,
789806                                       float  *dst, const  int  ncols,
790807                                       const  int  nrows,
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
10701087                mul_mat_vec_q5_K_q8_1_sycl (src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
10711088                break ;
10721089            case  GGML_TYPE_Q6_K:
1073-                 mul_mat_vec_q6_K_q8_1_sycl (src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1090+                 if  ((ggml_tensor_extra_gpu *) dst->src [0 ]->extra  &&
1091+                     ((ggml_tensor_extra_gpu *) dst->src [0 ]->extra )->optimized_feature .reorder ) {
1092+                     GGML_SYCL_DEBUG (" Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n " 
1093+                     reorder_mul_mat_vec_q6_k_q8_1_sycl (src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1094+                 } else  {
1095+                     GGML_SYCL_DEBUG (" Calling mul_mat_vec_q6_k_q8_1_sycl\n " 
1096+                     mul_mat_vec_q6_K_q8_1_sycl (src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1097+                 }
10741098                break ;
10751099            case  GGML_TYPE_IQ1_S:
10761100                mul_mat_vec_iq1_s_q8_1_sycl (src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
0 commit comments