44#include < sys/types.h>
55
66#include < cstdint>
7- #include < sycl/aliases.hpp>
8- #include < sycl/ext/oneapi/experimental/root_group.hpp>
9- #include < sycl/functional.hpp>
10- #include < sycl/group_algorithm.hpp>
11- #include < sycl/nd_item.hpp>
7+ #include < tuple>
8+
129#include < sycl/sycl.hpp>
13- #include < sycl/vector.hpp>
1410
1511#include " builtins.hpp"
1612#include " cacheopts.hpp"
1713#include " ggml-quants.h"
1814#include " ggml-sycl/dpct/helper.hpp"
1915
20- #define sycl_print sycl::ext::oneapi::experimental::printf
16+ __attribute__ ((always_inline)) inline std::tuple<int, float> quantize_and_pack_input(
17+ const sycl::vec<float , 4 > & loaded_fp32_vals, int wi_id_in_sg, sycl::sub_group & sg) {
18+ float amax = 0 ;
19+ int packed_quants = 0 ;
20+ #pragma unroll(4)
21+ for (int i = 0 ; i < 4 ; i++) {
22+ amax = sycl::fmax (amax, sycl::fabs (loaded_fp32_vals[i]));
23+ }
24+
25+ float amax_value_to_contribute = wi_id_in_sg > 7 ? 0 : amax;
26+
27+ // first reduce for workitems 0 - 7;
28+ float abs_max_0_7 = sycl::reduce_over_group (sg, amax_value_to_contribute, sycl::maximum<float >());
29+
30+ amax_value_to_contribute = wi_id_in_sg < 7 ? 0 : amax;
31+
32+ float abs_max_8_15 = sycl::reduce_over_group (sg, amax_value_to_contribute, sycl::maximum<float >());
33+
34+ float amax_value = wi_id_in_sg > 7 ? abs_max_8_15 : abs_max_0_7;
35+
36+ float scale_value = amax_value == 0 ? 1 : amax_value / 127 ;
37+
38+ #pragma unroll(4)
39+ for (int i = 0 ; i < 4 ; i++) {
40+ int8_t quantized_value = sycl::round (loaded_fp32_vals[i] / scale_value);
41+ packed_quants = packed_quants | (int32_t ) ((uint8_t ) quantized_value) << (8 * i);
42+ }
43+ scale_value = amax_value == 0 ? 0 : scale_value;
44+
45+ return { packed_quants, scale_value };
46+ }
47+
2148//
2249/* *
2350 * @brief This function packs 4 q6_k quants in a 32 bit value.
@@ -32,21 +59,23 @@ __attribute__((always_inline)) inline int pack_q6_k(const short & low_bits, cons
3259// TODO: Reduce the number of brackets by checking the precedence order :)
3360#pragma unroll(4)
3461 for (uint8_t i = 0 ; i < 4 ; i++) {
35- uint16_t mask_low_bits = (0x000F ) << (4 * i);
36- uint8_t mask_high_bits = (0x3 ) << (2 * i);
37- uint8_t desired_low_bits = (low_bits & mask_low_bits) >> (4 * i);
38- uint8_t desired_high_bits = ((high_bits & mask_high_bits) >> (2 * i)) << 4 ;
39- int8_t full_value = static_cast <int8_t >(desired_high_bits | desired_low_bits);
40- full_value = sycl::sub_sat (full_value, (int8_t )32 );
62+ uint16_t mask_low_bits = (0x000F ) << (4 * i);
63+ uint8_t mask_high_bits = (0x3 ) << (2 * i);
64+ uint8_t desired_low_bits = (low_bits & mask_low_bits) >> (4 * i);
65+ uint8_t desired_high_bits = ((high_bits & mask_high_bits) >> (2 * i)) << 4 ;
66+ int8_t full_value = static_cast <int8_t >(desired_high_bits | desired_low_bits);
67+ full_value = sycl::sub_sat (full_value, (int8_t ) 32 );
4168 packed_q6_k |= (static_cast <uint32_t >(static_cast <uint8_t >(full_value)) << (8 * i));
4269 }
4370 return packed_q6_k;
4471}
4572
4673namespace sycl {
47- __attribute__ ((always_inline)) inline void q6k_tiled_gemv (
48- const int8_t * q6_k_l, const int8_t * q6_k_h, const int8_t * q8_1, float * result, const int8_t * q6_u8_bit_scales,
49- const sycl::half2 * q8_dm_scales, const sycl::half * q6_k_superblock_scale, int m, int k, const nd_item<1 > & it) {
74+ __attribute__ ((always_inline)) inline void q6k_tiled_gemv (const int8_t * q6_k_l, const int8_t * q6_k_h,
75+ const float * q8_1, float * result,
76+ const int8_t * q6_u8_bit_scales,
77+ const sycl::half * q6_k_superblock_scale, int m, int k,
78+ const nd_item<1 > & it) {
5079 // Performs a (m x k ) X (k x 1) GEMM
5180 // Each subgroup is responsible for 16 output elements.
5281
@@ -69,18 +98,17 @@ __attribute__((always_inline)) inline void q6k_tiled_gemv(
6998 auto sg_id = it.get_group (0 ) * num_sgs_in_wg + sg.get_group_id ();
7099 auto wi_id_in_sg = sg.get_local_linear_id ();
71100
72- auto q6_k_l_width = ((k / 2 - 1 ) * sizeof (int8_t )); // as we have 2 4 bit values packed in an int8_t;
73- auto q6_k_h_width = ((k / 4 - 1 ) * sizeof (int8_t )); // as we have 4 2 bit values packed in an int8_t;
74- auto q8_1_width = (k - 1 ) * sizeof (int8_t );
75- auto result_width = (m - 1 ) * sizeof (float );
76- auto q6_u8_scale_width = ((k / QK_K) * 16 - 1 ) * sizeof (int8_t );
101+ auto q6_k_l_width = ((k / 2 - 1 ) * sizeof (int8_t )); // as we have 2 4 bit values packed in an int8_t;
102+ auto q6_k_h_width = ((k / 4 - 1 ) * sizeof (int8_t )); // as we have 4 2 bit values packed in an int8_t;
103+ auto result_width = (m - 1 ) * sizeof (float );
104+ auto q6_u8_scale_width = ((k / QK_K) * 16 - 1 ) * sizeof (int8_t );
77105 auto super_block_scale_width = (m - 1 ) * sizeof (sycl::half);
78106 const auto num_blocks_per_row = k / QK_K;
79107
80- const int tiles_required = m / tile_height;
81- sycl::vec<float , 16 > accumulator;
82- vector_types::char16 q6_u8_scales_vals;
83- sycl::half super_block_scale;
108+ const int tiles_required = m / tile_height;
109+ sycl::vec<float , 16 > accumulator;
110+ vector_types::char16 q6_u8_scales_vals;
111+ sycl::half super_block_scale;
84112
85113 for (; sg_id < tiles_required; sg_id += num_sgs_in_kernel) {
86114 auto h_coord = sg_id * tile_height;
@@ -99,11 +127,11 @@ __attribute__((always_inline)) inline void q6k_tiled_gemv(
99127 auto super_block_scale_loaded = __builtin_IB_subgroup_block_read_flat_u16_m1k16v1 (
100128 (intptr_t ) q6_k_superblock_scale, super_block_scale_width, num_blocks_per_row - 1 ,
101129 super_block_scale_width, vector_types::uint2{ (uint) (h_coord), (uint) i });
102- super_block_scale = *reinterpret_cast <sycl::half*>(&super_block_scale_loaded);
103-
130+ super_block_scale = *reinterpret_cast <sycl::half *>(&super_block_scale_loaded);
131+
104132 auto element_width_offset = i * QK_K;
105- auto q6_l_w_coord_start = i * (QK_K / 2 );
106- auto q6_h_w_coord_start = i * (QK_K / 4 );
133+ auto q6_l_w_coord_start = i * (QK_K / 2 );
134+ auto q6_h_w_coord_start = i * (QK_K / 4 );
107135
108136# pragma unroll(4)
109137 for (int j = 0 ; j < QK_K; j += tile_width) {
@@ -115,21 +143,22 @@ __attribute__((always_inline)) inline void q6k_tiled_gemv(
115143 (intptr_t ) (q6_k_h), q6_k_h_width, m - 1 , q6_k_h_width,
116144 vector_types::uint2{ (uint) (q6_h_w_coord_start + j / 4 ), (uint) h_coord });
117145
118- int packed_q8_1_vals = __builtin_IB_subgroup_block_read_flat_u8_m1k64v1 (
119- (intptr_t ) (q8_1), q8_1_width, 0 , q8_1_width,
120- vector_types::uint2{ (uint) (element_width_offset + j), (uint) 0 });
146+ auto loaded_fp32_vals = *reinterpret_cast <const sycl::vec<float , 4 > *>(q8_1 + element_width_offset + j);
147+ // int packed_q8_1_vals = __builtin_IB_subgroup_block_read_flat_u8_m1k64v1(
148+ // (intptr_t) (q8_1), q8_1_width, 0, q8_1_width,
149+ // vector_types::uint2{ (uint) (element_width_offset + j), (uint) 0 });
121150
122- sycl::half2 q8_dm_val =
123- q8_dm_scales[element_width_offset / QK8_1 + j / QK8_1 + (wi_id_in_sg * 4 ) / QK8_1];
151+ auto [packed_q8_1_vals, q8_scale_fp32] = quantize_and_pack_input (loaded_fp32_vals, wi_id_in_sg, sg);
124152
125153# pragma unroll(16)
126154 for (uint8_t l = 0 ; l < 16 ; l++) {
127155 int packed_q6_k_vals = pack_q6_k (q6_low_bits[l], q6_high_bits[l]);
128156 int dp4a_val = __builtin_IB_dp4a_ss (0 , packed_q6_k_vals, packed_q8_1_vals, dp4a_with_saturation);
129157 sycl::half q6_super_block_value = sycl::select_from_group (sg, super_block_scale, l);
130- int8_t q6_block_scale_val = sycl::select_from_group (sg, q6_u8_scales_vals[l], j / 16 + (wi_id_in_sg ) / 4 );
158+ int8_t q6_block_scale_val =
159+ sycl::select_from_group (sg, q6_u8_scales_vals[l], j / 16 + (wi_id_in_sg) / 4 );
131160 accumulator[l] += dp4a_val * static_cast <float >(q6_super_block_value) *
132- static_cast <float >(q6_block_scale_val) * static_cast < float >(q8_dm_val[ 0 ]) ;
161+ static_cast <float >(q6_block_scale_val) * q8_scale_fp32 ;
133162 }
134163 }
135164 }
0 commit comments