Skip to content

Commit dfa6e2b

Browse files
ikawrakowIwan Kawrakow
andauthored
CUDA: faster IQ2_K, IQ2_KS, IQ2_K_R4 (#716)
* Use bperm trick for iq2_ks gemm -> 7% gain * Use bperm trick for iq2_k gemm -> ~5% gain * Use bperm trick for iq2_k_r4 gemm -> ~3% gain * Use bperm trick for iq2_ks gemv -> ~7% gain * Use bperm trick for iq2_k gemv -> ~3% gain * Use bperm trick for iq2_k_r4 gemv -> ~7% gain --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 3b94f0a commit dfa6e2b

File tree

5 files changed

+190
-29
lines changed

5 files changed

+190
-29
lines changed

ggml/src/ggml-cuda/iqk_cuda_common.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,12 @@ __device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16
127127
return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16);
128128
}
129129

130+
#ifdef __CUDA_ARCH__
131+
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
132+
const uint32_t * values32 = (const uint32_t *)values;
133+
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
134+
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
135+
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
136+
}
137+
#endif
138+

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,34 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
849849
const uint32_t * q2 = (const uint32_t *)bq2->qs + 8*(i4/4) + 2*(i4%4);
850850
const uint16_t extra = bq2->extra >> (8*(i4/4) + (i4%4)/2);
851851

852+
const uint32_t * scales = (const uint32_t *)bq2->scales;
853+
uint32_t s32 = __vsub4((scales[i4/4] >> 4*(((i4%4)/2)%2)) & 0x0f0f0f0f, 0x08080808);
854+
const int8_t * s8 = (const int8_t *)&s32;
855+
856+
// Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
857+
// -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)
858+
859+
#ifdef __CUDA_ARCH__
860+
uint32_t extra32 = uint32_t(extra & 0xff) * 0x01010101;
861+
uint32_t extra32_1 = (extra32 << 2) & 0x44444444;
862+
uint32_t extra32_2 = (extra32 << 0) & 0x44444444;
863+
864+
uint32_t val1, val2;
865+
866+
val1 = ((q2[0] >> 0) & 0x33333333) | extra32_1; val2 = ((q2[1] >> 0) & 0x33333333) | extra32_1;
867+
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
868+
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
869+
int sumi1 = ggml_cuda_dp4a(v2.x, q8_1[1], ggml_cuda_dp4a(v1.x, q8_1[0], 0)) * s8[0];
870+
int sumi3 = ggml_cuda_dp4a(v2.y, q8_3[1], ggml_cuda_dp4a(v1.y, q8_3[0], 0)) * s8[2];
871+
872+
val1 = ((q2[0] >> 2) & 0x33333333) | extra32_2; val2 = ((q2[1] >> 2) & 0x33333333) | extra32_2;
873+
v1 = get_int_from_table_8(val1, iq2nl_values);
874+
v2 = get_int_from_table_8(val2, iq2nl_values);
875+
int sumi2 = ggml_cuda_dp4a(v2.x, q8_2[1], ggml_cuda_dp4a(v1.x, q8_2[0], 0)) * s8[1];
876+
int sumi4 = ggml_cuda_dp4a(v2.y, q8_4[1], ggml_cuda_dp4a(v1.y, q8_4[0], 0)) * s8[3];
877+
878+
#else
879+
852880
const int * all_values = (const int *)iq2k_table;
853881
const int * values;
854882

@@ -857,13 +885,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
857885
uint32_t aux32[2];
858886
int v1, v2;
859887

860-
// Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
861-
// -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)
862-
863-
const uint32_t * scales = (const uint32_t *)bq2->scales;
864-
uint32_t s32 = __vsub4((scales[i4/4] >> 4*(((i4%4)/2)%2)) & 0x0f0f0f0f, 0x08080808);
865-
const int8_t * s8 = (const int8_t *)&s32;
866-
867888
aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
868889
v1 = int_from_table_4(aux32[0], values);
869890
v2 = int_from_table_4(aux32[1], values);
@@ -883,6 +904,7 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
883904
v1 = int_from_table_4(aux32[0], values);
884905
v2 = int_from_table_4(aux32[1], values);
885906
int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];
907+
#endif
886908

887909
*result += __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
888910
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
@@ -908,14 +930,8 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
908930
const uint16_t * q2 = (const uint16_t *)bq2->qs + 16*(i4/4) + 4*(i4%4);
909931
const uint16_t extra = bq2->extra >> 4*(i4/4);
910932

911-
const int * all_values = (const int *)iq2k_table;
912-
const int * values;
913-
914933
uint32_t val1 = q2[0] | (q2[1] << 16), val2 = q2[2] | (q2[3] << 16);
915934

916-
uint32_t aux32[2];
917-
int v1, v2;
918-
919935
int32_t scales32;
920936
const uint16_t * scales16 = (const uint16_t *)bq2->scales;
921937
scales32 = __vsub4((scales16[i4/4] | (scales16[i4/4] << 12)) & 0x0f0f0f0f, 0x10101010);
@@ -925,6 +941,35 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
925941
s8[2] += ((extra >> 5) & 0x10);
926942
s8[3] += ((extra >> 7) & 0x10);
927943

944+
#ifdef __CUDA_ARCH__
945+
946+
uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101;
947+
948+
uint32_t this_extra = ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040);
949+
uint32_t idx1 = ((val1 >> 0) & 0x33333333) | this_extra;
950+
uint32_t idx2 = ((val2 >> 0) & 0x33333333) | this_extra;
951+
int2 v1 = get_int_from_table_8(idx1, iq2nl_values);
952+
int2 v2 = get_int_from_table_8(idx2, iq2nl_values);
953+
954+
int sumi1 = ggml_cuda_dp4a(v2.x, q8_1[1], ggml_cuda_dp4a(v1.x, q8_1[0], 0)) * s8[0];
955+
int sumi3 = ggml_cuda_dp4a(v2.y, q8_3[1], ggml_cuda_dp4a(v1.y, q8_3[0], 0)) * s8[1];
956+
957+
this_extra = ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040);
958+
idx1 = ((val1 >> 2) & 0x33333333) | this_extra;
959+
idx2 = ((val2 >> 2) & 0x33333333) | this_extra;
960+
v1 = get_int_from_table_8(idx1, iq2nl_values);
961+
v2 = get_int_from_table_8(idx2, iq2nl_values);
962+
963+
int sumi2 = ggml_cuda_dp4a(v2.x, q8_2[1], ggml_cuda_dp4a(v1.x, q8_2[0], 0)) * s8[2];
964+
int sumi4 = ggml_cuda_dp4a(v2.y, q8_4[1], ggml_cuda_dp4a(v1.y, q8_4[0], 0)) * s8[3];
965+
966+
#else
967+
968+
uint32_t aux32[2];
969+
int v1, v2;
970+
const int * all_values = (const int *)iq2k_table;
971+
const int * values;
972+
928973
aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
929974
v1 = int_from_table_4(aux32[0], values);
930975
v2 = int_from_table_4(aux32[1], values);
@@ -944,6 +989,7 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
944989
v1 = int_from_table_4(aux32[0], values);
945990
v2 = int_from_table_4(aux32[1], values);
946991
int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];
992+
#endif
947993

948994
*result += scale * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
949995
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
@@ -965,12 +1011,31 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
9651011
int is = ib16%2;
9661012
const int * scales_l = (const int *)bq2->scales;
9671013

968-
const int * all_values = (const int *)iq2k_table;
969-
9701014
int scales = __vsub4(((scales_l[2*(ib32%4)+is] >> 4*(ib32/4)) & 0x0f0f0f0f), 0x08080808);
9711015
const int8_t * s8 = (const int8_t *)&scales;
972-
int2 val1;
1016+
9731017
const int * q2 = (const int *)bq2->qs + 8*ib32 + 4*is;
1018+
1019+
#ifdef __CUDA_ARCH__
1020+
1021+
#pragma unroll
1022+
for (int i = 0; i < 4; ++i) {
1023+
uint32_t extra32 = uint32_t((bq2->extra[i+4*is] >> ib32) & 1) * 0x04040404;
1024+
extra32 |= (extra32 << 4);
1025+
uint32_t val1 = ((q2[i] >> 0) & 0x33333333) | extra32;
1026+
uint32_t val2 = ((q2[i] >> 2) & 0x33333333) | extra32;
1027+
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
1028+
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
1029+
int sumi = 0;
1030+
sumi = ggml_cuda_dp4a(v1.x, q8[0], ggml_cuda_dp4a(v2.x, q8[1], sumi));
1031+
sumi = ggml_cuda_dp4a(v1.y, q8[2], ggml_cuda_dp4a(v2.y, q8[3], sumi));
1032+
const float d = __half2float(bq2->d[i]) * d8;
1033+
result[i] += d * sumi * s8[i];
1034+
}
1035+
1036+
#else
1037+
const int * all_values = (const int *)iq2k_table;
1038+
int2 val1;
9741039
int aux32[2];
9751040
#pragma unroll
9761041
for (int i = 0; i < 4; ++i) {
@@ -989,6 +1054,7 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
9891054
const float d = __half2float(bq2->d[i]) * d8;
9901055
result[i] += d * sumi1 * s8[i];
9911056
}
1057+
#endif
9921058
}
9931059

9941060
#define VDR_IQ3_K_Q8_1_MMVQ 4

ggml/src/ggml-cuda/mmq.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
187187
break;
188188
case GGML_TYPE_IQ2_K:
189189
case GGML_TYPE_IQ2_K_R4:
190-
mmq_supported = ne11 < 2048;
190+
mmq_supported = ne11 <= 3072;
191191
break;
192192
case GGML_TYPE_IQ3_K:
193193
case GGML_TYPE_IQ4_K:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2566,11 +2566,45 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
25662566
float * x_df = (float *) (x_qs + txs.qs);
25672567
#endif // INT8_MMA_AVAILABLE
25682568

2569-
const int * all_values = (const int *)iq2k_table;
2570-
25712569
const int kqsx = threadIdx.x%16;
25722570

2573-
#pragma unroll
2571+
#ifdef __CUDA_ARCH__
2572+
#pragma unroll
2573+
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
2574+
int i = i0 + 2*threadIdx.y + threadIdx.x/16;
2575+
2576+
if (need_check) {
2577+
i = min(i, i_max);
2578+
}
2579+
2580+
const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0;
2581+
2582+
uint16_t extra = bxi->extra >> 4*(kqsx/8);
2583+
int q2 = get_int_b2(bxi->qs, kqsx);
2584+
2585+
uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101;
2586+
uint32_t val1 = ((q2 >> 0) & 0x33333333) | ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040);
2587+
uint32_t val2 = ((q2 >> 2) & 0x33333333) | ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040);
2588+
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
2589+
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
2590+
2591+
#ifdef INT8_MMA_AVAILABLE
2592+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = v1.x;
2593+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = v2.x;
2594+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = v1.y;
2595+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = v2.y;
2596+
#else
2597+
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = v1.x;
2598+
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = v2.x;
2599+
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = v1.y;
2600+
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = v2.y;
2601+
#endif // INT8_MMA_AVAILABLE
2602+
}
2603+
2604+
#else // __CUDA_ARCH__
2605+
2606+
const int * all_values = (const int *)iq2k_table;
2607+
#pragma unroll
25742608
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
25752609
int i = i0 + 2*threadIdx.y + threadIdx.x/16;
25762610

@@ -2595,6 +2629,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
25952629
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5));
25962630
#endif // INT8_MMA_AVAILABLE
25972631
}
2632+
#endif // __CUDA_ARCH__
25982633

25992634
#pragma unroll
26002635
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
@@ -2635,7 +2670,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
26352670
constexpr int qstep = 8;
26362671
const int kqsx = threadIdx.x % qstep;
26372672

2638-
#pragma unroll
2673+
#pragma unroll
26392674
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
26402675
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;
26412676

@@ -2645,13 +2680,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
26452680

26462681
const block_iq2_k * bxi = (const block_iq2_k *)(x + i*stride) + kbx0;
26472682

2648-
auto all_values = (const int *)iq2k_table;
2649-
26502683
const float d = bxi->d;
2651-
26522684
uint16_t extra = bxi->extra >> (kqsx/4);
26532685

2654-
#pragma unroll
2686+
#ifdef __CUDA_ARCH__
2687+
2688+
uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 };
2689+
#pragma unroll
2690+
for (int l = 0; l < qstep/4; ++l) {
2691+
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
2692+
uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 2) & 0x44444444);
2693+
uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 0) & 0x44444444);
2694+
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
2695+
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
2696+
#ifdef INT8_MMA_AVAILABLE
2697+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x;
2698+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x;
2699+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y;
2700+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y;
2701+
#else
2702+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x;
2703+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x;
2704+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y;
2705+
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y;
2706+
#endif // INT8_MMA_AVAILABLE
2707+
}
2708+
2709+
#else
2710+
2711+
auto all_values = (const int *)iq2k_table;
2712+
2713+
#pragma unroll
26552714
for (int l = 0; l < qstep/4; ++l) {
26562715

26572716
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
@@ -2670,6 +2729,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
26702729

26712730
extra >>= 8;
26722731
}
2732+
#endif // __CUDA_ARCH__
26732733

26742734
#ifdef INT8_MMA_AVAILABLE
26752735
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8);

ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1414
float * x_df = (float *) (x_qs + txs.qs);
1515
#endif // INT8_MMA_AVAILABLE
1616

17-
const int * all_values = (const int *)iq2k_table;
18-
1917
const int kqsx = threadIdx.x/4; // 0...7 -> block of 32
2018

2119
#pragma unroll
@@ -32,10 +30,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
3230

3331
const float d = __half2float(bxi->d[ir]);
3432

35-
#pragma unroll
33+
#ifdef __CUDA_ARCH__
34+
#pragma unroll
35+
for (int l = 0; l < 2; ++l) {
36+
37+
uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404;
38+
extra = extra | (extra << 4);
39+
40+
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
41+
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra;
42+
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra;
43+
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
44+
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
45+
46+
#ifdef INT8_MMA_AVAILABLE
47+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
48+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
49+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
50+
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
51+
#else
52+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
53+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
54+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
55+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
56+
#endif // INT8_MMA_AVAILABLE
57+
}
58+
59+
#else
60+
#pragma unroll
3661
for (int l = 0; l < 2; ++l) {
3762

38-
auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
63+
auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
3964

4065
const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
4166

@@ -51,6 +76,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
5176
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
5277
#endif // INT8_MMA_AVAILABLE
5378
}
79+
#endif // __CUDA_ARCH__
5480

5581
int is = 8*kqsx + ir;
5682
float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);

0 commit comments

Comments
 (0)