Skip to content

Commit 472c495

Browse files
committed
fuse q8_1 quantization in q6_k_tiled_gemv
1 parent ae5270e commit 472c495

File tree

3 files changed

+74
-46
lines changed

3 files changed

+74
-46
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3371,7 +3371,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
33713371
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
33723372
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
33733373
} else if (use_mul_mat_vec_q) {
3374-
constexpr bool convert_src1_to_q8_1 = true;
3374+
bool convert_src1_to_q8_1 = ctx.opt_feature.can_use_intel_builtins ? false : true;
33753375
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
33763376
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
33773377
} else if (use_mul_mat_q) {

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
#include "quants.hpp"
1515
#include "vecdotq.hpp"
1616

17-
static void q6_k_tiled_gemv(const int8_t * q6_k_low, const int8_t * q6_k_high, const int8_t * q8_1_input,
18-
const int8_t * q6_scales, const sycl::half * q6_k_superblock_scales,
19-
const sycl::half2 * q8_scales, float * output, std::size_t m, std::size_t k,
17+
static void q6_k_tiled_gemv(const int8_t * q6_k_low, const int8_t * q6_k_high, const float * src1_f32,
18+
const int8_t * q6_scales, const sycl::half * q6_k_superblock_scales,
19+
float * output, std::size_t m, std::size_t k,
2020
dpct::queue_ptr stream) {
2121
constexpr int SubgroupSize = 16;
2222
constexpr int tile_height = 16;
@@ -28,8 +28,8 @@ static void q6_k_tiled_gemv(const int8_t * q6_k_low, const int8_t * q6_k_high, c
2828
sycl_launch(stream, [&](sycl::handler & cgh) {
2929
sycl_parallel_for(cgh, sycl::nd_range<1>({ global_range }, { local_range }),
3030
[=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(SubgroupSize)]] {
31-
[[clang::always_inline]] sycl::q6k_tiled_gemv(q6_k_low, q6_k_high, q8_1_input, output,
32-
q6_scales, q8_scales,
31+
[[clang::always_inline]] sycl::q6k_tiled_gemv(q6_k_low, q6_k_high, src1_f32, output,
32+
q6_scales,
3333
q6_k_superblock_scales, m, k, it);
3434
});
3535
});
@@ -1005,6 +1005,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
10051005
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
10061006
const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
10071007
float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
1008+
const float* src1_ddfi_row = src1_ddf_i + i * src1_padded_col_size;
10081009
switch (src0->type) {
10091010
case GGML_TYPE_Q4_0:
10101011
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
@@ -1060,9 +1061,8 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
10601061
auto q6_h_ptr = q6_l_ptr + (QK_K / 2) * num_q6_blocks;
10611062
auto scales_u8_q6_k = q6_h_ptr + (QK_K / 4) * num_q6_blocks;
10621063
auto scales_q6_k_superblock = (sycl::half*)(scales_u8_q6_k + num_q6_blocks * (QK_K / 16));
1063-
auto q8_1_input = (int8_t *) src1_ddq_i_bs;
1064-
auto q8_1_input_scales = (sycl::half2 *) (q8_1_input + k);
1065-
q6_k_tiled_gemv(q6_l_ptr, q6_h_ptr, q8_1_input, scales_u8_q6_k, scales_q6_k_superblock, q8_1_input_scales, dst_dd_i_bs, m, k,
1064+
auto src_1_f32 = src1_ddfi_row;
1065+
q6_k_tiled_gemv(q6_l_ptr, q6_h_ptr, src_1_f32, scales_u8_q6_k, scales_q6_k_superblock, dst_dd_i_bs, m, k,
10661066
stream);
10671067
} else {
10681068
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
@@ -1107,6 +1107,5 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
11071107
}
11081108
GGML_UNUSED(src1);
11091109
GGML_UNUSED(dst);
1110-
GGML_UNUSED(src1_ddf_i);
11111110
GGML_UNUSED(ctx);
11121111
}

ggml/src/ggml-sycl/q6_k_tiled_gemv.hpp

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,47 @@
44
#include <sys/types.h>
55

66
#include <cstdint>
7-
#include <sycl/aliases.hpp>
8-
#include <sycl/ext/oneapi/experimental/root_group.hpp>
9-
#include <sycl/functional.hpp>
10-
#include <sycl/group_algorithm.hpp>
11-
#include <sycl/nd_item.hpp>
7+
#include <tuple>
8+
129
#include <sycl/sycl.hpp>
13-
#include <sycl/vector.hpp>
1410

1511
#include "builtins.hpp"
1612
#include "cacheopts.hpp"
1713
#include "ggml-quants.h"
1814
#include "ggml-sycl/dpct/helper.hpp"
1915

20-
#define sycl_print sycl::ext::oneapi::experimental::printf
16+
__attribute__((always_inline)) inline std::tuple<int, float> quantize_and_pack_input(
17+
const sycl::vec<float, 4> & loaded_fp32_vals, int wi_id_in_sg, sycl::sub_group & sg) {
18+
float amax = 0;
19+
int packed_quants = 0;
20+
#pragma unroll(4)
21+
for (int i = 0; i < 4; i++) {
22+
amax = sycl::fmax(amax, sycl::fabs(loaded_fp32_vals[i]));
23+
}
24+
25+
float amax_value_to_contribute = wi_id_in_sg > 7 ? 0 : amax;
26+
27+
// first reduce for workitems 0 - 7;
28+
float abs_max_0_7 = sycl::reduce_over_group(sg, amax_value_to_contribute, sycl::maximum<float>());
29+
30+
amax_value_to_contribute = wi_id_in_sg < 7 ? 0 : amax;
31+
32+
float abs_max_8_15 = sycl::reduce_over_group(sg, amax_value_to_contribute, sycl::maximum<float>());
33+
34+
float amax_value = wi_id_in_sg > 7 ? abs_max_8_15 : abs_max_0_7;
35+
36+
float scale_value = amax_value == 0 ? 1 : amax_value / 127;
37+
38+
#pragma unroll(4)
39+
for (int i = 0; i < 4; i++) {
40+
int8_t quantized_value = sycl::round(loaded_fp32_vals[i] / scale_value);
41+
packed_quants = packed_quants | (int32_t) ((uint8_t) quantized_value) << (8 * i);
42+
}
43+
scale_value = amax_value == 0 ? 0 : scale_value;
44+
45+
return { packed_quants, scale_value };
46+
}
47+
2148
//
2249
/**
2350
* @brief This function packs 4 q6_k quants in a 32 bit value.
@@ -32,21 +59,23 @@ __attribute__((always_inline)) inline int pack_q6_k(const short & low_bits, cons
3259
// TODO: Reduce the number of brackets by checking the precedence order :)
3360
#pragma unroll(4)
3461
for (uint8_t i = 0; i < 4; i++) {
35-
uint16_t mask_low_bits = (0x000F) << (4 * i);
36-
uint8_t mask_high_bits = (0x3) << (2 * i);
37-
uint8_t desired_low_bits = (low_bits & mask_low_bits) >> (4 * i);
38-
uint8_t desired_high_bits = ((high_bits & mask_high_bits) >> (2 * i)) << 4;
39-
int8_t full_value = static_cast<int8_t>(desired_high_bits | desired_low_bits);
40-
full_value = sycl::sub_sat(full_value, (int8_t)32);
62+
uint16_t mask_low_bits = (0x000F) << (4 * i);
63+
uint8_t mask_high_bits = (0x3) << (2 * i);
64+
uint8_t desired_low_bits = (low_bits & mask_low_bits) >> (4 * i);
65+
uint8_t desired_high_bits = ((high_bits & mask_high_bits) >> (2 * i)) << 4;
66+
int8_t full_value = static_cast<int8_t>(desired_high_bits | desired_low_bits);
67+
full_value = sycl::sub_sat(full_value, (int8_t) 32);
4168
packed_q6_k |= (static_cast<uint32_t>(static_cast<uint8_t>(full_value)) << (8 * i));
4269
}
4370
return packed_q6_k;
4471
}
4572

4673
namespace sycl {
47-
__attribute__((always_inline)) inline void q6k_tiled_gemv(
48-
const int8_t * q6_k_l, const int8_t * q6_k_h, const int8_t * q8_1, float * result, const int8_t * q6_u8_bit_scales,
49-
const sycl::half2 * q8_dm_scales, const sycl::half * q6_k_superblock_scale, int m, int k, const nd_item<1> & it) {
74+
__attribute__((always_inline)) inline void q6k_tiled_gemv(const int8_t * q6_k_l, const int8_t * q6_k_h,
75+
const float * q8_1, float * result,
76+
const int8_t * q6_u8_bit_scales,
77+
const sycl::half * q6_k_superblock_scale, int m, int k,
78+
const nd_item<1> & it) {
5079
// Performs a (m x k ) X (k x 1) GEMM
5180
// Each subgroup is responsible for 16 output elements.
5281

@@ -69,18 +98,17 @@ __attribute__((always_inline)) inline void q6k_tiled_gemv(
6998
auto sg_id = it.get_group(0) * num_sgs_in_wg + sg.get_group_id();
7099
auto wi_id_in_sg = sg.get_local_linear_id();
71100

72-
auto q6_k_l_width = ((k / 2 - 1) * sizeof(int8_t)); // as we have 2 4 bit values packed in an int8_t;
73-
auto q6_k_h_width = ((k / 4 - 1) * sizeof(int8_t)); // as we have 4 2 bit values packed in an int8_t;
74-
auto q8_1_width = (k - 1) * sizeof(int8_t);
75-
auto result_width = (m - 1) * sizeof(float);
76-
auto q6_u8_scale_width = ((k / QK_K) * 16 - 1) * sizeof(int8_t);
101+
auto q6_k_l_width = ((k / 2 - 1) * sizeof(int8_t)); // as we have 2 4 bit values packed in an int8_t;
102+
auto q6_k_h_width = ((k / 4 - 1) * sizeof(int8_t)); // as we have 4 2 bit values packed in an int8_t;
103+
auto result_width = (m - 1) * sizeof(float);
104+
auto q6_u8_scale_width = ((k / QK_K) * 16 - 1) * sizeof(int8_t);
77105
auto super_block_scale_width = (m - 1) * sizeof(sycl::half);
78106
const auto num_blocks_per_row = k / QK_K;
79107

80-
const int tiles_required = m / tile_height;
81-
sycl::vec<float, 16> accumulator;
82-
vector_types::char16 q6_u8_scales_vals;
83-
sycl::half super_block_scale;
108+
const int tiles_required = m / tile_height;
109+
sycl::vec<float, 16> accumulator;
110+
vector_types::char16 q6_u8_scales_vals;
111+
sycl::half super_block_scale;
84112

85113
for (; sg_id < tiles_required; sg_id += num_sgs_in_kernel) {
86114
auto h_coord = sg_id * tile_height;
@@ -99,11 +127,11 @@ __attribute__((always_inline)) inline void q6k_tiled_gemv(
99127
auto super_block_scale_loaded = __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(
100128
(intptr_t) q6_k_superblock_scale, super_block_scale_width, num_blocks_per_row - 1,
101129
super_block_scale_width, vector_types::uint2{ (uint) (h_coord), (uint) i });
102-
super_block_scale = *reinterpret_cast<sycl::half*>(&super_block_scale_loaded);
103-
130+
super_block_scale = *reinterpret_cast<sycl::half *>(&super_block_scale_loaded);
131+
104132
auto element_width_offset = i * QK_K;
105-
auto q6_l_w_coord_start = i * (QK_K / 2);
106-
auto q6_h_w_coord_start = i * (QK_K / 4);
133+
auto q6_l_w_coord_start = i * (QK_K / 2);
134+
auto q6_h_w_coord_start = i * (QK_K / 4);
107135

108136
# pragma unroll(4)
109137
for (int j = 0; j < QK_K; j += tile_width) {
@@ -115,21 +143,22 @@ __attribute__((always_inline)) inline void q6k_tiled_gemv(
115143
(intptr_t) (q6_k_h), q6_k_h_width, m - 1, q6_k_h_width,
116144
vector_types::uint2{ (uint) (q6_h_w_coord_start + j / 4), (uint) h_coord });
117145

118-
int packed_q8_1_vals = __builtin_IB_subgroup_block_read_flat_u8_m1k64v1(
119-
(intptr_t) (q8_1), q8_1_width, 0, q8_1_width,
120-
vector_types::uint2{ (uint) (element_width_offset + j), (uint) 0 });
146+
auto loaded_fp32_vals = *reinterpret_cast<const sycl::vec<float, 4> *>(q8_1 + element_width_offset + j);
147+
// int packed_q8_1_vals = __builtin_IB_subgroup_block_read_flat_u8_m1k64v1(
148+
// (intptr_t) (q8_1), q8_1_width, 0, q8_1_width,
149+
// vector_types::uint2{ (uint) (element_width_offset + j), (uint) 0 });
121150

122-
sycl::half2 q8_dm_val =
123-
q8_dm_scales[element_width_offset / QK8_1 + j / QK8_1 + (wi_id_in_sg * 4) / QK8_1];
151+
auto [packed_q8_1_vals, q8_scale_fp32] = quantize_and_pack_input(loaded_fp32_vals, wi_id_in_sg, sg);
124152

125153
# pragma unroll(16)
126154
for (uint8_t l = 0; l < 16; l++) {
127155
int packed_q6_k_vals = pack_q6_k(q6_low_bits[l], q6_high_bits[l]);
128156
int dp4a_val = __builtin_IB_dp4a_ss(0, packed_q6_k_vals, packed_q8_1_vals, dp4a_with_saturation);
129157
sycl::half q6_super_block_value = sycl::select_from_group(sg, super_block_scale, l);
130-
int8_t q6_block_scale_val = sycl::select_from_group(sg, q6_u8_scales_vals[l], j / 16 + (wi_id_in_sg ) / 4);
158+
int8_t q6_block_scale_val =
159+
sycl::select_from_group(sg, q6_u8_scales_vals[l], j / 16 + (wi_id_in_sg) / 4);
131160
accumulator[l] += dp4a_val * static_cast<float>(q6_super_block_value) *
132-
static_cast<float>(q6_block_scale_val) * static_cast<float>(q8_dm_val[0]);
161+
static_cast<float>(q6_block_scale_val) * q8_scale_fp32;
133162
}
134163
}
135164
}

0 commit comments

Comments
 (0)