4343#include " ggml-sycl/gemm.hpp"
4444#include " ggml-sycl/sycl_hw.hpp"
4545#include " ggml-sycl/getrows.hpp"
46+ #include " ggml-sycl/quantize.hpp"
4647#include " ggml.h"
4748
4849static bool g_sycl_loaded = false ;
@@ -1374,120 +1375,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
13741375
13751376
13761377
1377- template <int QUANT_BLOCK_TILE>
1378- static void quantize_q8_1 (const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
1379- const sycl::nd_item<3 > &item_ct1) {
1380- const int ix = (item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
1381- item_ct1.get_local_id (2 )) * QUANT_BLOCK_TILE;
1382-
1383- if (ix >= kx_padded) {
1384- return ;
1385- }
1386-
1387- const int iy = item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
1388- item_ct1.get_local_id (1 );
1389-
1390- const int i_padded = iy*kx_padded + ix;
1391-
1392- block_q8_1 * y = (block_q8_1 *) vy;
1393-
1394- const int ib = i_padded / QK8_1; // block index
1395- const int iqs = i_padded % QK8_1; // quant index
1396- typedef sycl::vec<float , QUANT_BLOCK_TILE> TC;
1397- typedef sycl::vec<int8_t , QUANT_BLOCK_TILE> TQ;
1398- TC zeros;
1399- TQ qzeros;
1400- #pragma unroll
1401- for (int i = 0 ; i < QUANT_BLOCK_TILE; i++)
1402- {
1403- zeros[i] = 0 .f ;
1404- qzeros[i] = 0 ;
1405- }
1406- const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
1407- float sum = xi[0 ];
1408- float amax = sycl::fabs (xi[0 ]);
1409- #pragma unroll
1410- for (int i = 1 ; i < QUANT_BLOCK_TILE; i++)
1411- {
1412- sum += xi[i];
1413- amax = sycl::fmax (sycl::fabs (xi[i]), amax);
1414- }
1415- sum = warp_reduce_sum (sum, item_ct1);
1416- amax = warp_reduce_max (amax, item_ct1);
1417-
1418- const float d = amax / 127 ;
1419- TQ q = qzeros;
1420- if (amax != 0 .0f )
1421- {
1422- #pragma unroll
1423- for (int i = 0 ; i < QUANT_BLOCK_TILE; i++) {
1424- q[i] = sycl::round (xi[i] / d);
1425- }
1426- }
1427-
1428- *(TQ *)&y[ib].qs [iqs] = q;
1429-
1430- if (iqs > 0 ) {
1431- return ;
1432- }
1433-
1434- reinterpret_cast <sycl::half &>(y[ib].ds .x ()) = d;
1435- reinterpret_cast <sycl::half &>(y[ib].ds .y ()) = sum;
1436- }
1437-
1438- template <int ElementsPerWI>
1439- static __dpct_inline__ void quantize_and_reorder_q8_1 (const float * __restrict__ x, void * reordered_q8_tensor,
1440- const int kx, const int kx_padded, const sycl::nd_item<1 > & it) {
1441- /*
1442- Quantizes and reorders the resultant q8 tensor in a per row fashion
1443- Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
1444- */
1445-
1446- auto subgroup_id = it.get_group (0 );
1447- auto wi_id = it.get_local_id (0 );
1448-
1449- const int num_blocks_per_row = kx / QK8_1;
1450- auto row = subgroup_id / num_blocks_per_row;
1451- auto col = subgroup_id % num_blocks_per_row;
1452-
1453- auto row_offset = row * (kx_padded / QK8_1) * sizeof (block_q8_1);
1454- auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
1455-
1456- auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
1457- auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof (sycl::half2));
1458-
1459- sycl::vec<float , ElementsPerWI> wi_f32_vals;
1460- sycl::vec<int8_t , ElementsPerWI> quantized_values;
1461-
1462- auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
1463- wi_f32_vals = *reinterpret_cast <const sycl::vec<float , ElementsPerWI> *>(x + float_ptr_offset);
1464-
1465- float sum = 0 .0f ;
1466- float amax = 0 .0f ;
1467-
1468- #pragma unroll(ElementsPerWI)
1469- for (int i = 0 ; i < ElementsPerWI; i++) {
1470- sum += wi_f32_vals[i];
1471- amax = sycl::fmax (amax, sycl::fabs (wi_f32_vals[i]));
1472- quantized_values[i] = 0 ;
1473- }
1474- sum = sycl::reduce_over_group (it.get_group (), sum, sycl::plus<float >());
1475- amax = sycl::reduce_over_group (it.get_group (), amax, sycl::maximum<float >());
1476- float d = amax == 0 ? 1 : amax / 127 ;
1477-
1478- #pragma unroll(ElementsPerWI)
1479- for (int i = 0 ; i < ElementsPerWI; i++) {
1480- quantized_values[i] = sycl::round (wi_f32_vals[i] / d);
1481- }
1482-
1483- d = amax == 0 ? 0 : d;
1484-
1485- *reinterpret_cast <sycl::vec<int8_t , ElementsPerWI> *>(quant_ptr) = quantized_values;
1486- if (wi_id == 0 ) {
1487- *ds_ptr = sycl::half2 (sycl::half (d), sycl::half (sum));
1488- }
1489- }
1490-
14911378static void mul_mat_p021_f16_f32 (
14921379 const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
14931380 const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1772,32 +1659,6 @@ static void pool2d_nchw_kernel(
17721659 o_ptr[cur_oh * ow + cur_ow] = res;
17731660}
17741661
1775- static void quantize_row_q8_1_sycl (const float * x, void * vy, const int kx, const int ky, const int kx_padded,
1776- bool reorder_q8_tensor, queue_ptr stream) {
1777- if (reorder_q8_tensor) {
1778- auto local_range = std::size_t (WARP_SIZE);
1779- auto num_quant_blocks = ky * (kx / QK8_1);
1780- auto global_range = num_quant_blocks * local_range;
1781- stream->parallel_for (sycl::nd_range<1 >({ global_range }, { local_range }),
1782- [=](sycl::nd_item<1 > it) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
1783- quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
1784- });
1785- } else {
1786- const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1 ) / SYCL_QUANTIZE_BLOCK_SIZE;
1787- const sycl::range<3 > num_blocks (1 , ky, block_num_x);
1788- int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1789- static_assert (QK8_1 % WARP_SIZE == 0 );
1790- const sycl::range<3 > block_size (1 , 1 , SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1791- {
1792- dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
1793-
1794- stream->parallel_for (sycl::nd_range<3 >(num_blocks * block_size, block_size),
1795- [=](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
1796- quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1797- });
1798- }
1799- }
1800- }
18011662
18021663static void ggml_mul_mat_p021_f16_f32_sycl (const void *vx, const float *y,
18031664 float *dst, const int ncols_x,
@@ -2380,10 +2241,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
23802241 peer_access_enabled = enable_peer_access;
23812242}
23822243
2244+ template <template <int > typename quantize_f>
23832245static void ggml_sycl_op_mul_mat (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
23842246 const ggml_tensor *src1, ggml_tensor *dst,
2385- ggml_sycl_op_mul_mat_t op,
2386- const bool convert_src1_to_q8_1) try {
2247+ ggml_sycl_op_mul_mat_t op) try {
23872248
23882249 GGML_TENSOR_LOCALS (int64_t , ne0, src0, ne);
23892250
@@ -2478,6 +2339,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24782339 }
24792340 }
24802341
2342+ constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
2343+ no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
24812344 for (int i = 0 ; i < ggml_sycl_info ().device_count ; ++i) {
24822345 if ((!split && i != ctx.device ) || dev[i].row_low == dev[i].row_high ) {
24832346 continue ;
@@ -2503,20 +2366,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25032366 dev[i].src1_ddf = dev[i].src1_ddf_alloc .alloc (ctx.pool (i), ggml_nelements (src1));
25042367 }
25052368
2506- if (convert_src1_to_q8_1 ) {
2369+ if constexpr (quantize_enabled ) {
25072370 dev[i].src1_ddq = dev[i].src1_ddq_alloc .alloc (ctx.pool (i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
25082371
25092372 if (src1_on_device && src1_is_contiguous) {
2510- bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra )->optimized_feature .reorder ;
25112373 scope_op_debug_print scope_dbg_print (__func__, " /quantize_row_q8_1_sycl" , dst,
25122374 /* num_src=*/ 2 , " : converting src1 to Q8_1" );
2513- quantize_row_q8_1_sycl (dev[i]. src1_ddf , dev[i]. src1_ddq , ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
2514- /*
2515- DPCT1010:90: SYCL uses exceptions to report errors and does not
2516- use the error codes. The call was replaced with 0. You need to
2517- rewrite this code.
2518- */
2519- SYCL_CHECK ( 0 );
2375+ try {
2376+ quantize_row_q8_1_sycl<quantize_f>(dev[i]. src1_ddf , dev[i]. src1_ddq , ne10, nrows1, src1_padded_col_size, stream);
2377+ } catch (sycl::exception const &exc) {
2378+ std::cerr << " Quantize_row_q8_1_sycl error " << exc. what () << " Exception caught at file: " << __FILE__
2379+ << " , line: " << __LINE__ << std::endl;
2380+ std::exit ( 1 );
2381+ }
25202382 }
25212383 }
25222384
@@ -2590,7 +2452,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25902452 // copy src0, src1 to device if necessary
25912453 if (src1_is_contiguous) {
25922454 if (i != ctx.device ) {
2593- if (convert_src1_to_q8_1 ) {
2455+ if constexpr (quantize_enabled ) {
25942456 char * src1_ddq_i_source = dev[ctx.device ].src1_ddq + src1_ddq_i_offset;
25952457 SYCL_CHECK (CHECK_TRY_ERROR (stream->memcpy (
25962458 src1_ddq_i, src1_ddq_i_source,
@@ -2613,16 +2475,18 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
26132475 GGML_ABORT (" fatal error" );
26142476 }
26152477
2616- if (convert_src1_to_q8_1 && !src1_is_contiguous) {
2617- scope_op_debug_print scope_dbg_print (__func__, " /quantize_row_q8_1_sycl" , dst,
2618- /* num_src=*/ 2 , " : converting src1 to Q8_1" );
2619- quantize_row_q8_1_sycl (src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false , stream);
2620- /*
2621- DPCT1010:92: SYCL uses exceptions to report errors and does
2622- not use the error codes. The call was replaced with 0. You
2623- need to rewrite this code.
2624- */
2625- SYCL_CHECK (0 );
2478+ if constexpr (quantize_enabled) {
2479+ if (!src1_is_contiguous) {
2480+ scope_op_debug_print scope_dbg_print (__func__, " /quantize_row_q8_1_sycl" , dst,
2481+ /* num_src=*/ 2 , " : converting src1 to Q8_1" );
2482+ try {
2483+ quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
2484+ } catch (sycl::exception const &exc) {
2485+ std::cerr << " Quantize_row_q8_1_sycl error" << exc.what () << " Exception caught at file:" << __FILE__
2486+ << " , line:" << __LINE__ << std::endl;
2487+ std::exit (1 );
2488+ }
2489+ }
26262490 }
26272491
26282492 if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0 ) {
@@ -3277,19 +3141,20 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
32773141 // KQ + KQV multi-batch
32783142 ggml_sycl_mul_mat_batched_sycl (ctx, src0, src1, dst);
32793143 } else if (use_dequantize_mul_mat_vec) {
3280- constexpr bool convert_src1_to_q8_1 = false ;
32813144 opt_for_reorder (&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3282- ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1 );
3145+ ggml_sycl_op_mul_mat<no_quantize_q8_1> (ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
32833146 } else if (use_mul_mat_vec_q) {
3284- constexpr bool convert_src1_to_q8_1 = true ;
32853147 opt_for_reorder (&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3286- ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
3148+ ggml_tensor_extra_gpu * extra = static_cast <ggml_tensor_extra_gpu *>(src0->extra );
3149+ if (extra && extra->optimized_feature .reorder ) {
3150+ ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3151+ } else {
3152+ ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3153+ }
32873154 } else if (use_mul_mat_q) {
3288- constexpr bool convert_src1_to_q8_1 = true ;
3289- ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
3155+ ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
32903156 } else {
3291- constexpr bool convert_src1_to_q8_1 = false ;
3292- ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
3157+ ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
32933158 }
32943159}
32953160
0 commit comments