@@ -91,6 +91,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
9191 }
9292}
9393
94+ template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder>
95+ static void dequantize_mul_mat_vec_reorder (const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
96+ const sycl::nd_item<3 > &item_ct1) {
97+ // qk = quantized weights per x block
98+ // qr = number of quantized weights per data value in x block
99+ const int row = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) +
100+ item_ct1.get_local_id (1 );
101+
102+ if (row >= nrows) {
103+ return ;
104+ }
105+
106+ const int tid = item_ct1.get_local_id (2 );
107+
108+
109+ const int ncols_left = ncols % (QK4_0*WARP_SIZE);
110+ const int ncols_align = ncols - ncols_left;
111+ const int iter_stride = 8 *2 *GGML_SYCL_DMMV_X;
112+ const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
113+ const int y_offset = qr == 1 ? 1 : qk/2 ;
114+
115+ // partial sum for each thread
116+ #ifdef GGML_SYCL_F16
117+ sycl::half2 tmp = {0 .0f , 0 .0f }; // two sums for f16 to take advantage of half2 intrinsics
118+ #else
119+ float tmp = 0 .0f ;
120+ #endif // GGML_SYCL_F16
121+ const char *d_ptr = (const char *)vx+ncols*nrows/2 ;
122+ int i=0 ;
123+ for (i = 0 ; i < ncols_align; i += iter_stride) {
124+ const int col = i + vals_per_iter*tid;
125+ const int ib = (row*ncols + col)/qk; // x block index
126+ const int iqs = (col%qk)/qr; // x quant index
127+ const int iybs = col - col%qk; // y block start index
128+
129+ // processing >2 values per i iter is faster for fast GPUs
130+ #pragma unroll
131+ for (int j = 0 ; j < vals_per_iter; j += 2 ) {
132+ // process 2 vals per j iter
133+
134+ // dequantize
135+ // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
136+ dfloat2 v;
137+ dequantize_kernel_recorder ((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
138+
139+ // matrix multiplication
140+ // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
141+ #ifdef GGML_SYCL_F16
142+ dfloat2 t1{y[iybs + iqs + j / qr + 0 ],
143+ y[iybs + iqs + j / qr + y_offset]};
144+
145+ tmp += v * t1;
146+ #else
147+ tmp += v.x () * y[iybs + iqs + j / qr + 0 ];
148+ tmp += v.y () * y[iybs + iqs + j / qr + y_offset];
149+ #endif // GGML_SYCL_F16
150+ }
151+ }
152+
153+ for (; i < ncols; i += iter_stride) {
154+ if (tid>=ncols_left/QK4_0) continue ;
155+ const int col = i + vals_per_iter*tid;
156+ const int ib = (row*ncols + col)/qk; // x block index
157+ const int iqs = (col%qk)/qr; // x quant index
158+ const int iybs = col - col%qk; // y block start index
159+
160+ // processing >2 values per i iter is faster for fast GPUs
161+ #pragma unroll
162+ for (int j = 0 ; j < vals_per_iter; j += 2 ) {
163+ // process 2 vals per j iter
164+
165+ // dequantize
166+ // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
167+ dfloat2 v;
168+ dequantize_kernel_recorder ((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
169+
170+ // matrix multiplication
171+ // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
172+ #ifdef GGML_SYCL_F16
173+ dfloat2 t1{y[iybs + iqs + j / qr + 0 ],
174+ y[iybs + iqs + j / qr + y_offset]};
175+
176+ tmp += v * t1;
177+ #else
178+ tmp += v.x () * y[iybs + iqs + j / qr + 0 ];
179+ tmp += v.y () * y[iybs + iqs + j / qr + y_offset];
180+ #endif // GGML_SYCL_F16
181+ }
182+ }
183+
184+ // sum up partial sums and write back result
185+ const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2 ;
186+ for (int mask = mask_start; mask > 0 ; mask >>= 1 ) {
187+ tmp +=
188+ dpct::permute_sub_group_by_xor (item_ct1.get_sub_group (), tmp, mask);
189+ }
190+
191+ if (tid == 0 ) {
192+ #ifdef GGML_SYCL_F16
193+ dst[row] = tmp.x () + tmp.y ();
194+ #else
195+ dst[row] = tmp;
196+ #endif // GGML_SYCL_F16
197+ }
198+ }
199+
94200static void convert_mul_mat_vec_f16_sycl (const void *vx, const dfloat *y,
95201 float *dst, const int ncols,
96202 const int nrows,
@@ -760,6 +866,29 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
760866}
761867
762868
869+ static void dequantize_mul_mat_vec_q4_0_sycl_reorder (const void *vx, const dfloat *y,
870+ float *dst, const int ncols,
871+ const int nrows,
872+ dpct::queue_ptr stream) {
873+ GGML_ASSERT (ncols % GGML_SYCL_DMMV_X == 0 );
874+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1 ) / GGML_SYCL_MMV_Y;
875+ // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
876+ const sycl::range<3 > block_nums (1 , 1 , block_num_y);
877+ const sycl::range<3 > block_dims (1 , GGML_SYCL_MMV_Y, WARP_SIZE);
878+ {
879+ dpct::has_capability_or_fail (stream->get_device (),
880+ {sycl::aspect::fp16});
881+
882+ stream->parallel_for (
883+ sycl::nd_range<3 >(block_nums * block_dims, block_dims),
884+ [=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (WARP_SIZE)]] {
885+ dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
886+ vx, y, dst, ncols, nrows, item_ct1);
887+ });
888+ }
889+ }
890+
891+
763892static void dequantize_mul_mat_vec_q4_0_sycl (const void *vx, const dfloat *y,
764893 float *dst, const int ncols,
765894 const int nrows,
@@ -977,7 +1106,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
9771106
9781107 switch (src0->type ) {
9791108 case GGML_TYPE_Q4_0:
1109+ #if defined(GGML_SYCL_INTEL_TARGET)
1110+ dequantize_mul_mat_vec_q4_0_sycl_reorder (src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1111+ #else
9801112 dequantize_mul_mat_vec_q4_0_sycl (src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1113+ #endif
9811114 break ;
9821115 case GGML_TYPE_Q4_1:
9831116 dequantize_mul_mat_vec_q4_1_sycl (src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -1020,4 +1153,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
10201153 GGML_UNUSED (src1_ddq_i);
10211154 GGML_UNUSED (src1_ncols);
10221155 GGML_UNUSED (src1_padded_row_size);
1156+ GGML_UNUSED (ctx);
10231157}
0 commit comments