Skip to content

Commit acd80ec

Browse files
committed
working q8 reorder commit
1 parent 6096ff8 commit acd80ec

File tree

3 files changed

+38
-18
lines changed

3 files changed

+38
-18
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,19 +1418,8 @@ template <int ElementsPerWI>
14181418
static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
14191419
const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
14201420
/*
1421-
quantize and reorders the resultant q8 tensor in a per row fashion
1422-
Each sub-group calculates one quant block
1423-
work_group_size = sub_group_size;
1424-
1425-
|------------------------------ Matrix Pitch -------------------------|
1426-
|------- Matrix Width --------|
1427-
q_00 q_01 q_02 ..... q_0n-1 q_n ds00 ds01 ... ds0n/32 ... padding ... |
1428-
. . |
1429-
. . |
1430-
. . Matrix Height
1431-
. . |
1432-
. . |
1433-
q_n0 q_n1 q_n2 ..... q_nn-1 q_n dsn0 dsn1 ... dsnn/32 ... padding ... |
1421+
Quantizes and reorders the resultant q8 tensor in a per row fashion
1422+
Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
14341423
*/
14351424

14361425
auto subgroup_id = it.get_group(0);

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
4040
// Y block index that aligns with ibx
4141
const int iby = i * block_type::block_to_q8_1_ratio();
4242
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
43-
sycl::half2 q8_1_ds_ptr = *(const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
43+
const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2));
4444

4545
#pragma unroll
4646
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {

ggml/src/ggml-sycl/vecdotq.hpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
285285
}
286286

287287
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
288-
const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int /* nblocks */) {
288+
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
289289
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
290290
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
291291
int v[q4_0_traits::vdr_mmvq];
@@ -299,7 +299,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
299299
u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi);
300300
}
301301

302-
return vec_dot_q4_0_q8_1_impl(v, u, d, q8_1_ds);
302+
return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds);
303303
};
304304
};
305305

@@ -347,7 +347,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
347347
using q4_k_traits = typename q4_k_block::traits;
348348

349349
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
350-
const int8_t* q8_1_quant_ptr, const sycl::half2& q8_1_ds, const int & iqs, int nblocks) {
350+
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
351351
const int ib = ibx_offset / (QK_K / 2);
352352

353353
const uint8_t * base = static_cast<const uint8_t *>(vbq);
@@ -360,7 +360,38 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
360360
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
361361
const uint16_t * scales = (const uint16_t *) scs;
362362

363-
return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
363+
int v[2];
364+
int u[2 * QR4_K];
365+
float d8[QR4_K];
366+
367+
v[0] = q4[0];
368+
v[1] = q4[4];
369+
370+
uint16_t aux[2];
371+
const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
372+
if (j < 2) {
373+
aux[0] = scales[j + 0] & 0x3f3f;
374+
aux[1] = scales[j + 2] & 0x3f3f;
375+
} else {
376+
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
377+
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
378+
}
379+
380+
const uint8_t * sc = (const uint8_t *) aux;
381+
const uint8_t * m = sc + 2;
382+
383+
for (int i = 0; i < QR4_K; ++i) {
384+
const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1;
385+
sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i);
386+
387+
d8[i] = ds_values[0];
388+
389+
const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4);
390+
u[2 * i + 0] = q8[0];
391+
u[2 * i + 1] = q8[4];
392+
}
393+
394+
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8);
364395
}
365396
};
366397

0 commit comments

Comments
 (0)