Skip to content

Commit 78de773

Browse files
Authored by ikawrakow (Iwan Kawrakow)
CUDA: faster prompt processing for 4-bit quants (#713)
Summary: use __byte_perm in get_int_from_table_16, and use get_int_from_table_16 everywhere for 4-bit quants. Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0cb6696 commit 78de773

File tree

4 files changed

+76
-63
lines changed

4 files changed

+76
-63
lines changed

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,25 @@ __device__ __forceinline__ void vec_dot_iq4_k_q8_1(
246246
}
247247

248248
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
#if defined(__CUDA_ARCH__)
    // Fast path: view the 16-entry int8 table as four 32-bit words and do
    // eight byte lookups at a time with __byte_perm (PRMT).  Per the CUDA
    // Math API, __byte_perm uses only bits <2:0> of each selector nibble,
    // so feeding raw q4 nibbles (0..15) indexes within an 8-byte half;
    // the nibble's high bit picks the half separately via `half_sel`.
    const uint32_t * tbl32 = (const uint32_t *)values;

    // Base pattern 0x32103210 keeps byte k of the first operand; OR-ing in
    // bit 3 of each q4 nibble shifted into bit 2 adds 4 to the selector,
    // switching that byte to the second operand (the upper table half).
    const uint32_t half_sel = 0x32103210 | ((q4 & 0x88888888) >> 1);

    // Lookups for the low 16 bits of q4 (four nibbles at once).
    uint32_t lo = __byte_perm(tbl32[0], tbl32[1], q4); // candidates from table[0..7]
    uint32_t hi = __byte_perm(tbl32[2], tbl32[3], q4); // candidates from table[8..15]
    const uint32_t res_lo = __byte_perm(lo, hi, half_sel);

    // Repeat for the high 16 bits of q4.
    lo = __byte_perm(tbl32[0], tbl32[1], q4 >> 16);
    hi = __byte_perm(tbl32[2], tbl32[3], q4 >> 16);
    const uint32_t res_hi = __byte_perm(lo, hi, half_sel >> 16);

    // De-interleave: even selector bytes carry the low nibbles of each input
    // byte (-> .x), odd selector bytes the high nibbles (-> .y).
    return make_int2(__byte_perm(res_lo, res_hi, 0x6420), __byte_perm(res_lo, res_hi, 0x7531));
#else
    // Portable fallback: one scalar table lookup per nibble.
    const int q0_32       = (q4 >> 0) & 0x0F0F0F0F;
    const int8_t * q0_8   = (const int8_t *) &q0_32;
    const char4    val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);

    const int q1_32       = (q4 >> 4) & 0x0F0F0F0F;
    const int8_t * q1_8   = (const int8_t *) &q1_32;
    const char4    val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);

    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
259279

260280
__device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1(
@@ -389,19 +409,18 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(
389409

390410
float scale = *(const float *)vbq;
391411
const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx;
392-
const uint8_t * all_values = (const uint8_t *)iq4k_values;
393412

394413
// iqs is 0...28
395414
const int ib32 = iqs/4; // Why iqs/4 ?
396415
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
397416
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
398417
const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
399-
int v1, v2;
418+
auto values = iq4k_values + ((bq4->scales[ib32] & 1) << 4);
400419
int sumi = 0;
401420
for (int j = 0; j < 4; ++j) {
402-
get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
403-
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
404-
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
421+
auto v = get_int_from_table_16(q4[j], values);
422+
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
423+
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
405424
}
406425
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
407426
}
@@ -560,7 +579,6 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1(
560579

561580
float scale = *(const float *)vbq;
562581
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
563-
const uint8_t * all_values = (const uint8_t *)iq4k_values;
564582

565583
// iqs is 0...28
566584
const int ib32 = iqs/4; // Why iqs/4 ?
@@ -569,14 +587,14 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1(
569587
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
570588
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
571589
const float dl = scale * ((ls & 254) - 127);
572-
int v1, v2;
590+
auto values = iq4k_values + ((ls & 1) << 4);
573591
int sumi = 0;
574592
for (int j = 0; j < 4; ++j) {
575593
uint32_t aux32 = q4[j] & 0xfffefffe;
576594
aux32 ^= (aux32 >> 1);
577-
get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
578-
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
579-
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
595+
auto v = get_int_from_table_16(aux32, values);
596+
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
597+
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
580598
}
581599
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
582600
}

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2509,9 +2509,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
25092509
const int kbx = 0; // threadIdx.x / QI4_XS
25102510
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
25112511

2512-
uint32_t aux32[2];
2513-
auto a8 = (const uint8_t *)aux32;
2514-
25152512
#pragma unroll
25162513
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
25172514
int i = i0 + threadIdx.y;
@@ -2523,15 +2520,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
25232520
const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;
25242521

25252522
const int q4 = get_int_b4(bxi->qs, kqsx);
2526-
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
2527-
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
2523+
const int2 v = get_int_from_table_16(q4);
25282524
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
25292525
#ifdef INT8_MMA_AVAILABLE
2530-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
2531-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
2526+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2527+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
25322528
#else
2533-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
2534-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
2529+
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2530+
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
25352531
#endif // INT8_MMA_AVAILABLE
25362532
}
25372533

@@ -2842,9 +2838,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
28422838

28432839
const int kqsx = threadIdx.x / 4;
28442840

2845-
uint32_t aux32[2];
2846-
auto a8 = (const uint8_t *)aux32;
2847-
28482841
#pragma unroll
28492842
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
28502843
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2857,19 +2850,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
28572850
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
28582851
const int ls = (bxi->scales[kqsx] & 254) - 127;
28592852

2860-
auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
2853+
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
28612854

28622855
#pragma unroll
28632856
for (int j = 0; j < 4; ++j) {
28642857
const int q4 = get_int_b4(bxi->qs, 4*kqsx+j);
2865-
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
2866-
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
2858+
const int2 v = get_int_from_table_16(q4, values);
28672859
#ifdef INT8_MMA_AVAILABLE
2868-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
2869-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
2860+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
2861+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
28702862
#else
2871-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
2872-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
2863+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
2864+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
28732865
#endif // INT8_MMA_AVAILABLE
28742866
}
28752867
#ifdef INT8_MMA_AVAILABLE
@@ -2896,9 +2888,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
28962888

28972889
const int kqsx = threadIdx.x/4;
28982890

2899-
uint32_t aux32[2];
2900-
const uint8_t * a8 = (const uint8_t *)aux32;
2901-
29022891
#pragma unroll
29032892
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
29042893
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2913,19 +2902,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
29132902
const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0;
29142903

29152904
const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
2916-
auto values = iq4k_table + ((bxi->scales[4*kqsx+ir] & 1) << 8);
2905+
auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4);
2906+
29172907
#pragma unroll
29182908
for (int j = 0; j < 4; ++j) {
29192909
const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir);
2920-
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
2921-
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
2910+
const int2 v = get_int_from_table_16(q4, values);
29222911
const int k0 = 8*kqsx + 4*(j%2) + j/2;
29232912
#ifdef INT8_MMA_AVAILABLE
2924-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, values);
2925-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = int_from_table_x(a8+4, values);
2913+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2914+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y;
29262915
#else
2927-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, values);
2928-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = int_from_table_x(a8+4, values);
2916+
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2917+
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y;
29292918
#endif // INT8_MMA_AVAILABLE
29302919
}
29312920
#ifdef INT8_MMA_AVAILABLE

ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1414

1515
const int kqsx = threadIdx.x / 4;
1616

17-
uint32_t aux32[2];
18-
auto a8 = (const uint8_t *)aux32;
19-
2017
#pragma unroll
2118
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
2219
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -31,20 +28,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
3128
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
3229
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
3330

34-
auto values = iq4k_table + ((ls & 1) << 8);
31+
auto values = iq4k_values + ((ls & 1) << 4);
3532

3633
#pragma unroll
3734
for (int j = 0; j < 4; ++j) {
3835
uint32_t val = q4[j] & 0xfffefffe;
3936
val = val ^ (val >> 1);
40-
aux32[0] = (val >> 0) & 0x0f0f0f0f;
41-
aux32[1] = (val >> 4) & 0x0f0f0f0f;
37+
auto v = get_int_from_table_16(val, values);
4238
#ifdef INT8_MMA_AVAILABLE
43-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
44-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
39+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
40+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
4541
#else
46-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
47-
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
42+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
43+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
4844
#endif // INT8_MMA_AVAILABLE
4945
}
5046
#ifdef INT8_MMA_AVAILABLE

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,21 +1126,26 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
11261126
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
11271127
}
11281128

1129-
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
1130-
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
1131-
const int8_t * q0_8 = (const int8_t *) &q0_32;
1132-
const char4 val0_8 = make_char4(
1133-
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
1134-
1135-
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
1136-
const int8_t * q1_8 = (const int8_t *) &q1_32;
1137-
const char4 val1_8 = make_char4(
1138-
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
1139-
1140-
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
1141-
}
1142-
11431129
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
#if defined(__CUDA_ARCH__)
    // Device fast path: the 16-byte table is reinterpreted as four uint32
    // words so that __byte_perm (PRMT) can perform four byte lookups per
    // instruction.  __byte_perm considers only bits <2:0> of each selector
    // nibble, so q4 can be used as a selector directly; bit 3 of every
    // nibble (which half of the table) is resolved afterwards with `msel`.
    const uint32_t * table32 = (const uint32_t *)values;

    // msel nibble = byte position (0..3) | 4 when the q4 nibble has its
    // high bit set — i.e. take the result of the upper-half lookup.
    const uint32_t msel = 0x32103210 | ((q4 & 0x88888888) >> 1);

    uint32_t low_half  = __byte_perm(table32[0], table32[1], q4);      // table[0..7]
    uint32_t high_half = __byte_perm(table32[2], table32[3], q4);      // table[8..15]
    const uint32_t r0  = __byte_perm(low_half, high_half, msel);       // nibbles 0..3

    low_half  = __byte_perm(table32[0], table32[1], q4 >> 16);
    high_half = __byte_perm(table32[2], table32[3], q4 >> 16);
    const uint32_t r1 = __byte_perm(low_half, high_half, msel >> 16);  // nibbles 4..7

    // Gather even-positioned results (low nibbles of each byte) into .x and
    // odd-positioned results (high nibbles) into .y.
    return make_int2(__byte_perm(r0, r1, 0x6420), __byte_perm(r0, r1, 0x7531));
#else
    // Host/fallback path: plain per-nibble table lookups.
    const int q0_32       = (q4 >> 0) & 0x0F0F0F0F;
    const int8_t * q0_8   = (const int8_t *) &q0_32;
    const char4    val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);

    const int q1_32       = (q4 >> 4) & 0x0F0F0F0F;
    const int8_t * q1_8   = (const int8_t *) &q1_32;
    const char4    val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);

    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}
1160+
1161+
// Convenience overload: expand 4-bit indices through the default IQ4_NL
// non-linear value table.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
    return get_int_from_table_16(q4, kvalues_iq4nl);
}
11541164

11551165
#define VDR_IQ4_NL_Q8_1_MMVQ 2

0 commit comments

Comments
 (0)