@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
285285 }
286286
287287 __dpct_inline__ float operator ()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
288- const int8_t * q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int /* nblocks */ ) {
288+ const int8_t * q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */ ) {
289289 const uint8_t * bq4_0 = static_cast <const uint8_t *>(vbq) + ibx_offset;
290290 const ggml_half d = *(reinterpret_cast <const ggml_half *>(static_cast <const uint8_t *>(vbq) + d_offset));
291291 int v[q4_0_traits::vdr_mmvq];
@@ -299,7 +299,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
299299 u[2 * i + 1 ] = get_int_from_int8_aligned (q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
300300 }
301301
302- return vec_dot_q4_0_q8_1_impl (v, u, d, q8_1_ds);
302+ return vec_dot_q4_0_q8_1_impl (v, u, d, * q8_1_ds);
303303 };
304304};
305305
@@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
347347 using q4_k_traits = typename q4_k_block::traits;
348348
349349 float operator ()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
350- const int8_t * q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int nblocks) {
350+ const int8_t * q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
351351 const int ib = ibx_offset / (QK_K / 2 );
352352
353353 const uint8_t * base = static_cast <const uint8_t *>(vbq);
@@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
360360 const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2 ) % 4 ));
361361 const uint16_t * scales = (const uint16_t *) scs;
362362
363- return vec_dot_q4_K_q8_1_common (q4, scales, *dms, bq8_1, iqs);
363+ int v[2 ];
364+ int u[2 * QR4_K];
365+ float d8[QR4_K];
366+
367+ v[0 ] = q4[0 ];
368+ v[1 ] = q4[4 ];
369+
370+ uint16_t aux[2 ];
371+ const int j = (QR4_K * ((iqs / 2 ) / (QI8_1 / 2 ))) / 2 ;
372+ if (j < 2 ) {
373+ aux[0 ] = scales[j + 0 ] & 0x3f3f ;
374+ aux[1 ] = scales[j + 2 ] & 0x3f3f ;
375+ } else {
376+ aux[0 ] = ((scales[j + 2 ] >> 0 ) & 0x0f0f ) | ((scales[j - 2 ] & 0xc0c0 ) >> 2 );
377+ aux[1 ] = ((scales[j + 2 ] >> 4 ) & 0x0f0f ) | ((scales[j - 0 ] & 0xc0c0 ) >> 2 );
378+ }
379+
380+ const uint8_t * sc = (const uint8_t *) aux;
381+ const uint8_t * m = sc + 2 ;
382+
383+ for (int i = 0 ; i < QR4_K; ++i) {
384+ const int8_t * quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
385+ sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
386+
387+ d8[i] = ds_values[0 ];
388+
389+ const int * q8 = (const int *) quant_base_ptr + ((iqs / 2 ) % 4 );
390+ u[2 * i + 0 ] = q8[0 ];
391+ u[2 * i + 1 ] = q8[4 ];
392+ }
393+
394+ return vec_dot_q4_K_q8_1_impl_vmmq (v, u, sc, m, *dms, d8);
364395 }
365396};
366397
0 commit comments