Skip to content

Commit 78e232a

Browse files
committed
opt performance by reorder for Intel GPU
1 parent a7b8ce2 commit 78e232a

File tree

6 files changed

+382
-7
lines changed

6 files changed

+382
-7
lines changed

examples/sycl/run-llama2.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# MIT license
44
# Copyright (C) 2024 Intel Corporation
55
# SPDX-License-Identifier: MIT
6-
6+
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
77
source /opt/intel/oneapi/setvars.sh
88

99
#export GGML_SYCL_DEBUG=1
@@ -13,7 +13,7 @@ source /opt/intel/oneapi/setvars.sh
1313
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
1414
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
1515
NGL=33
16-
CONEXT=8192
16+
CONEXT=4096
1717

1818
if [ $# -gt 0 ]; then
1919
GGML_SYCL_DEVICE=$1

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
1+
message(STATUS "GML_SYCL_TARGET=${GGML_SYCL_TARGET}")
2+
13
if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
24
message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
35
endif()
46

7+
if (GGML_SYCL_TARGET STREQUAL "INTEL")
8+
add_compile_definitions(GGML_SYCL_INTEL_TARGET)
9+
endif()
10+
11+
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
12+
add_compile_definitions(GGML_SYCL_NVIDIA_TARGET)
13+
endif()
14+
15+
if (GGML_SYCL_TARGET STREQUAL "AMD")
16+
add_compile_definitions(GGML_SYCL_AMD_TARGET)
17+
endif()
18+
19+
520
check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
621

722
if (DEFINED ENV{ONEAPI_ROOT})

ggml/src/ggml-sycl/convert.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,26 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
125125
}
126126
}
127127

128+
129+
template <typename dst_t>
130+
static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
131+
dpct::queue_ptr stream) {
132+
133+
dpct::has_capability_or_fail(stream->get_device(),
134+
{sycl::aspect::fp16});
135+
136+
int constexpr WARP_K = WARP_SIZE * QK4_0;
137+
const int n_warp = (k + WARP_K - 1) / WARP_K;
138+
GGML_ASSERT(k % 2 == 0);
139+
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
140+
sycl::range<3>(1, 1, WARP_SIZE),
141+
sycl::range<3>(1, 1, WARP_SIZE)),
142+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
143+
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
144+
});
145+
146+
}
147+
128148
template <typename dst_t>
129149
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
130150
dpct::queue_ptr stream) {
@@ -455,7 +475,11 @@ static void convert_unary_sycl(const void *__restrict__ vx,
455475
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
456476
switch (type) {
457477
case GGML_TYPE_Q4_0:
478+
#if defined(GGML_SYCL_INTEL_TARGET)
479+
return dequantize_row_q4_0_sycl_reorder;
480+
#else
458481
return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
482+
#endif
459483
case GGML_TYPE_Q4_1:
460484
return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
461485
case GGML_TYPE_Q5_0:
@@ -502,7 +526,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
502526
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
503527
switch (type) {
504528
case GGML_TYPE_Q4_0:
529+
#if defined(GGML_SYCL_INTEL_TARGET)
530+
return dequantize_row_q4_0_sycl_reorder;
531+
#else
505532
return dequantize_row_q4_0_sycl;
533+
#endif
506534
case GGML_TYPE_Q4_1:
507535
return dequantize_row_q4_1_sycl;
508536
case GGML_TYPE_Q5_0:

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "common.hpp"
1717

1818
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
19+
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
20+
const int iqs, dfloat2 &v);
1921

2022
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
2123
const int iqs, dfloat2 &v) {
@@ -40,6 +42,29 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
4042
#endif // GGML_SYCL_F16
4143
}
4244

45+
static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
46+
const int iqs, dfloat2 &v) {
47+
// const block_q4_0 * x = (const block_q4_0 *) vx;
48+
49+
const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);
50+
51+
const int vui = *((const uint8_t *)qs+iqs);
52+
53+
v.x() = vui & 0xF;
54+
v.y() = vui >> 4;
55+
56+
#ifdef GGML_SYCL_F16
57+
// v = v - {8.0f, 8.0f};
58+
// v = v * {d, d};
59+
v.s0() = (v.s0() - 8.0f) * d;
60+
v.s1() = (v.s1() - 8.0f) * d;
61+
62+
#else
63+
v.x() = (v.x() - 8.0f) * d;
64+
v.y() = (v.y() - 8.0f) * d;
65+
#endif // GGML_SYCL_F16
66+
}
67+
4368
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
4469
const int iqs, dfloat2 &v) {
4570
const block_q4_1 * x = (const block_q4_1 *) vx;
@@ -167,6 +192,36 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
167192
}
168193
}
169194

195+
template<typename dst_t>
196+
static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
197+
const sycl::nd_item<3> &item_ct1) {
198+
199+
const int64_t i = item_ct1.get_group(2);
200+
auto k=nb32;
201+
// assume 32 threads
202+
const int64_t tid = item_ct1.get_local_id(2);
203+
const int lane_ib = i * WARP_SIZE + tid;
204+
205+
if (lane_ib >= k / QK4_0) {
206+
return;
207+
}
208+
209+
dst_t * y_ptr = yy + lane_ib * QK4_0;
210+
211+
auto qs = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
212+
auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
213+
214+
const float d = float(*s_ptr);
215+
216+
#pragma unroll
217+
for (int l = 0; l < QK4_0 / 2; ++l) {
218+
int vq = qs[l];
219+
y_ptr[l + 0] = d * ((vq & 0xF) - 8);
220+
y_ptr[l + 16] = d * ((vq >> 4) - 8);
221+
}
222+
223+
}
224+
170225
template<typename dst_t>
171226
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
172227
const sycl::nd_item<3> &item_ct1) {

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,112 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
9191
}
9292
}
9393

94+
template <int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder>
95+
static void dequantize_mul_mat_vec_reorder(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
96+
const sycl::nd_item<3> &item_ct1) {
97+
// qk = quantized weights per x block
98+
// qr = number of quantized weights per data value in x block
99+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
100+
item_ct1.get_local_id(1);
101+
102+
if (row >= nrows) {
103+
return;
104+
}
105+
106+
const int tid = item_ct1.get_local_id(2);
107+
108+
109+
const int ncols_left = ncols % (QK4_0*WARP_SIZE);
110+
const int ncols_align = ncols - ncols_left;
111+
const int iter_stride = 8*2*GGML_SYCL_DMMV_X;
112+
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter //64/16=4, 512/16/2= 16
113+
const int y_offset = qr == 1 ? 1 : qk/2;
114+
115+
// partial sum for each thread
116+
#ifdef GGML_SYCL_F16
117+
sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
118+
#else
119+
float tmp = 0.0f;
120+
#endif // GGML_SYCL_F16
121+
const char *d_ptr = (const char*)vx+ncols*nrows/2;
122+
int i=0;
123+
for (i = 0; i < ncols_align; i += iter_stride) {
124+
const int col = i + vals_per_iter*tid;
125+
const int ib = (row*ncols + col)/qk; // x block index
126+
const int iqs = (col%qk)/qr; // x quant index
127+
const int iybs = col - col%qk; // y block start index
128+
129+
// processing >2 values per i iter is faster for fast GPUs
130+
#pragma unroll
131+
for (int j = 0; j < vals_per_iter; j += 2) {
132+
// process 2 vals per j iter
133+
134+
// dequantize
135+
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
136+
dfloat2 v;
137+
dequantize_kernel_recorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
138+
139+
// matrix multiplication
140+
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
141+
#ifdef GGML_SYCL_F16
142+
dfloat2 t1{y[iybs + iqs + j / qr + 0],
143+
y[iybs + iqs + j / qr + y_offset]};
144+
145+
tmp += v * t1;
146+
#else
147+
tmp += v.x() * y[iybs + iqs + j / qr + 0];
148+
tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
149+
#endif // GGML_SYCL_F16
150+
}
151+
}
152+
153+
for (; i < ncols; i += iter_stride) {
154+
if (tid>=ncols_left/QK4_0) continue;
155+
const int col = i + vals_per_iter*tid;
156+
const int ib = (row*ncols + col)/qk; // x block index
157+
const int iqs = (col%qk)/qr; // x quant index
158+
const int iybs = col - col%qk; // y block start index
159+
160+
// processing >2 values per i iter is faster for fast GPUs
161+
#pragma unroll
162+
for (int j = 0; j < vals_per_iter; j += 2) {
163+
// process 2 vals per j iter
164+
165+
// dequantize
166+
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
167+
dfloat2 v;
168+
dequantize_kernel_recorder((const void *)d_ptr, ib, (const void *)vx, ib * QK4_0 / 2 +iqs+j/qr, v);
169+
170+
// matrix multiplication
171+
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
172+
#ifdef GGML_SYCL_F16
173+
dfloat2 t1{y[iybs + iqs + j / qr + 0],
174+
y[iybs + iqs + j / qr + y_offset]};
175+
176+
tmp += v * t1;
177+
#else
178+
tmp += v.x() * y[iybs + iqs + j / qr + 0];
179+
tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
180+
#endif // GGML_SYCL_F16
181+
}
182+
}
183+
184+
// sum up partial sums and write back result
185+
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
186+
for (int mask = mask_start; mask > 0; mask >>= 1) {
187+
tmp +=
188+
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
189+
}
190+
191+
if (tid == 0) {
192+
#ifdef GGML_SYCL_F16
193+
dst[row] = tmp.x() + tmp.y();
194+
#else
195+
dst[row] = tmp;
196+
#endif // GGML_SYCL_F16
197+
}
198+
}
199+
94200
static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
95201
float *dst, const int ncols,
96202
const int nrows,
@@ -760,6 +866,29 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
760866
}
761867

762868

869+
static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y,
870+
float *dst, const int ncols,
871+
const int nrows,
872+
dpct::queue_ptr stream) {
873+
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
874+
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
875+
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
876+
const sycl::range<3> block_nums(1, 1, block_num_y);
877+
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
878+
{
879+
dpct::has_capability_or_fail(stream->get_device(),
880+
{sycl::aspect::fp16});
881+
882+
stream->parallel_for(
883+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
884+
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
885+
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
886+
vx, y, dst, ncols, nrows, item_ct1);
887+
});
888+
}
889+
}
890+
891+
763892
static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
764893
float *dst, const int ncols,
765894
const int nrows,
@@ -977,7 +1106,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
9771106

9781107
switch (src0->type) {
9791108
case GGML_TYPE_Q4_0:
1109+
#if defined(GGML_SYCL_INTEL_TARGET)
1110+
dequantize_mul_mat_vec_q4_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1111+
#else
9801112
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
1113+
#endif
9811114
break;
9821115
case GGML_TYPE_Q4_1:
9831116
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
@@ -1020,4 +1153,5 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
10201153
GGML_UNUSED(src1_ddq_i);
10211154
GGML_UNUSED(src1_ncols);
10221155
GGML_UNUSED(src1_padded_row_size);
1156+
GGML_UNUSED(ctx);
10231157
}

0 commit comments

Comments
 (0)