9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/iqk_cuda_common.h
@@ -127,3 +127,12 @@ __device__ __forceinline__ int int_from_table_x(const uint8_t * a8, const uint16
return values[a8[0] | (a8[1] << 4)] | (values[a8[2] | (a8[3] << 4)] << 16);
}

#ifdef __CUDA_ARCH__
static __device__ __forceinline__ int2 get_int_from_table_8(const int & q4, const int8_t * values) {
const uint32_t * values32 = (const uint32_t *)values;
uint32_t v1 = __byte_perm(values32[0], values32[1], q4);
uint32_t v2 = __byte_perm(values32[0], values32[1], q4 >> 16);
return make_int2(__byte_perm(v1, v2, 0x6420), __byte_perm(v1, v2, 0x7531));
}
#endif
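
The helper above is the core of this PR: in its default mode, __byte_perm(lo, hi, sel) treats the pair {hi, lo} as an eight-byte pool and picks one byte per selector nibble using that nibble's low three bits. With the eight int8_t table entries loaded into two 32-bit registers, a single instruction therefore resolves four 3-bit indices at once, and the final 0x6420/0x7531 selectors de-interleave the even and odd nibble positions of q4. A minimal host-side sketch of the idea follows; byte_perm_emu, the test table, and the index pattern are illustrative assumptions, not code from this repo.

// Host-side emulation of the default PRMT mode behind get_int_from_table_8.
// byte_perm_emu and the test values below are illustrative assumptions.
#include <cstdint>
#include <cstdio>
#include <cstring>

static uint32_t byte_perm_emu(uint32_t a, uint32_t b, uint32_t s) {
    const uint64_t pool = ((uint64_t)b << 32) | a;   // bytes 0..7 of {b, a}
    uint32_t r = 0;
    for (int i = 0; i < 4; ++i) {
        const uint32_t sel = (s >> 4*i) & 0x7;       // low 3 selector bits pick a byte
        r |= uint32_t((pool >> 8*sel) & 0xff) << 8*i;
    }
    return r;  // the kernels never set bit 3 of a selector nibble
}

int main() {
    // Any 8-entry table works; 2 data bits plus 1 extra bit keep indices in 0..7.
    const int8_t values[8] = {-31, -13, 1, 17, -26, -8, 6, 22};
    uint32_t values32[2];
    memcpy(values32, values, 8);
    const uint32_t q4 = 0x76543210;                  // eight nibble-sized indices
    const uint32_t v1 = byte_perm_emu(values32[0], values32[1], q4);
    const uint32_t v2 = byte_perm_emu(values32[0], values32[1], q4 >> 16);
    printf("%08x %08x\n", byte_perm_emu(v1, v2, 0x6420), byte_perm_emu(v1, v2, 0x7531));
    return 0;
}

With q4 = 0x76543210 the first printed word packs values[0], values[2], values[4], values[6] and the second packs values[1], values[3], values[5], values[7], which is exactly the even/odd de-interleave the dp4a callers in iqk_mmvq.cu rely on.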

98 changes: 82 additions & 16 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -849,6 +849,34 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
const uint32_t * q2 = (const uint32_t *)bq2->qs + 8*(i4/4) + 2*(i4%4);
const uint16_t extra = bq2->extra >> (8*(i4/4) + (i4%4)/2);

const uint32_t * scales = (const uint32_t *)bq2->scales;
uint32_t s32 = __vsub4((scales[i4/4] >> 4*(((i4%4)/2)%2)) & 0x0f0f0f0f, 0x08080808);
const int8_t * s8 = (const int8_t *)&s32;

// Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
// -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)

#ifdef __CUDA_ARCH__
uint32_t extra32 = uint32_t(extra & 0xff) * 0x01010101;
uint32_t extra32_1 = (extra32 << 2) & 0x44444444;
uint32_t extra32_2 = (extra32 << 0) & 0x44444444;

uint32_t val1, val2;

val1 = ((q2[0] >> 0) & 0x33333333) | extra32_1; val2 = ((q2[1] >> 0) & 0x33333333) | extra32_1;
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
int sumi1 = ggml_cuda_dp4a(v2.x, q8_1[1], ggml_cuda_dp4a(v1.x, q8_1[0], 0)) * s8[0];
int sumi3 = ggml_cuda_dp4a(v2.y, q8_3[1], ggml_cuda_dp4a(v1.y, q8_3[0], 0)) * s8[2];

val1 = ((q2[0] >> 2) & 0x33333333) | extra32_2; val2 = ((q2[1] >> 2) & 0x33333333) | extra32_2;
v1 = get_int_from_table_8(val1, iq2nl_values);
v2 = get_int_from_table_8(val2, iq2nl_values);
int sumi2 = ggml_cuda_dp4a(v2.x, q8_2[1], ggml_cuda_dp4a(v1.x, q8_2[0], 0)) * s8[1];
int sumi4 = ggml_cuda_dp4a(v2.y, q8_4[1], ggml_cuda_dp4a(v1.y, q8_4[0], 0)) * s8[3];

#else

const int * all_values = (const int *)iq2k_table;
const int * values;

@@ -857,13 +885,6 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
uint32_t aux32[2];
int v1, v2;

// Block of 16: (32*(4*(i4/4)+k)+8*(i4%4))/16 = 8*(i4/4) + 2*k + (i4%4)/2
// -> scales_l[4*(i4/4) + k] >> 4*(((i4%4)/2)%2)

const uint32_t * scales = (const uint32_t *)bq2->scales;
uint32_t s32 = __vsub4((scales[i4/4] >> 4*(((i4%4)/2)%2)) & 0x0f0f0f0f, 0x08080808);
const int8_t * s8 = (const int8_t *)&s32;

aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
v1 = int_from_table_4(aux32[0], values);
v2 = int_from_table_4(aux32[1], values);
@@ -883,6 +904,7 @@ __device__ __forceinline__ void vec_dot_iq2_k_q8_1(
v1 = int_from_table_4(aux32[0], values);
v2 = int_from_table_4(aux32[1], values);
int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];
#endif

*result += __half2float(bq2->d) * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
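
A note on how the new __CUDA_ARCH__ branch above builds its table indices: extra carries one flag bit per block of 16, and multiplying its low byte by 0x01010101 broadcasts those flags into all four byte lanes. A single shift-and-mask per crumb position then drops the right flag into bit 2 of every nibble, turning each 2-bit quant into a 3-bit index into the 8-entry table. A worked example, where the flag byte 0x15 is an assumed value for illustration:

// Sketch only: 0x15 has extra bits 0, 2, 4 set.
uint32_t extra32 = uint32_t(0x15) * 0x01010101;  // 0x15151515: flags in every lane
uint32_t e1 = (extra32 << 2) & 0x44444444;       // extra bits 0 and 4 -> 0x44444444
uint32_t e2 = (extra32 << 0) & 0x44444444;       // extra bits 2 and 6 -> 0x04040404
// OR-ing e1 into the crumbs taken at shift 0 and e2 into those at shift 2
// selects, per block, between the two 4-entry halves of the table.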
@@ -908,14 +930,8 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
const uint16_t * q2 = (const uint16_t *)bq2->qs + 16*(i4/4) + 4*(i4%4);
const uint16_t extra = bq2->extra >> 4*(i4/4);

const int * all_values = (const int *)iq2k_table;
const int * values;

uint32_t val1 = q2[0] | (q2[1] << 16), val2 = q2[2] | (q2[3] << 16);

uint32_t aux32[2];
int v1, v2;

int32_t scales32;
const uint16_t * scales16 = (const uint16_t *)bq2->scales;
scales32 = __vsub4((scales16[i4/4] | (scales16[i4/4] << 12)) & 0x0f0f0f0f, 0x10101010);
@@ -925,6 +941,35 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
s8[2] += ((extra >> 5) & 0x10);
s8[3] += ((extra >> 7) & 0x10);

#ifdef __CUDA_ARCH__

uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101;

uint32_t this_extra = ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040);
uint32_t idx1 = ((val1 >> 0) & 0x33333333) | this_extra;
uint32_t idx2 = ((val2 >> 0) & 0x33333333) | this_extra;
int2 v1 = get_int_from_table_8(idx1, iq2nl_values);
int2 v2 = get_int_from_table_8(idx2, iq2nl_values);

int sumi1 = ggml_cuda_dp4a(v2.x, q8_1[1], ggml_cuda_dp4a(v1.x, q8_1[0], 0)) * s8[0];
int sumi3 = ggml_cuda_dp4a(v2.y, q8_3[1], ggml_cuda_dp4a(v1.y, q8_3[0], 0)) * s8[1];

this_extra = ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040);
idx1 = ((val1 >> 2) & 0x33333333) | this_extra;
idx2 = ((val2 >> 2) & 0x33333333) | this_extra;
v1 = get_int_from_table_8(idx1, iq2nl_values);
v2 = get_int_from_table_8(idx2, iq2nl_values);

int sumi2 = ggml_cuda_dp4a(v2.x, q8_2[1], ggml_cuda_dp4a(v1.x, q8_2[0], 0)) * s8[2];
int sumi4 = ggml_cuda_dp4a(v2.y, q8_4[1], ggml_cuda_dp4a(v1.y, q8_4[0], 0)) * s8[3];

#else

uint32_t aux32[2];
int v1, v2;
const int * all_values = (const int *)iq2k_table;
const int * values;

aux32[0] = ((val1 >> 0) & 0x03030303); aux32[1] = ((val2 >> 0) & 0x03030303); values = all_values + ((extra & 0x01) << 8);
v1 = int_from_table_4(aux32[0], values);
v2 = int_from_table_4(aux32[1], values);
@@ -944,6 +989,7 @@ __device__ __forceinline__ void vec_dot_iq2_ks_q8_1(
v1 = int_from_table_4(aux32[0], values);
v2 = int_from_table_4(aux32[1], values);
int sumi4 = ggml_cuda_dp4a(v2, q8_4[1], ggml_cuda_dp4a(v1, q8_4[0], 0)) * s8[3];
#endif

*result += scale * (__low2float(bq8_1[4*(i4/4)+0].ds) * sumi1
+ __low2float(bq8_1[4*(i4/4)+1].ds) * sumi2
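
The iq2_ks path above reuses the broadcast trick with a 4-bit extra field, but the two nibbles of each index byte now belong to different blocks of 32, so the low-nibble flag (bit 2, mask 0x04040404) and the high-nibble flag (bit 6, mask 0x40404040) must come from different extra bits. A sketch, with the flag nibble 0x5 assumed for illustration:

uint32_t extra32 = uint32_t(0x5) * 0x01010101;          // flags 0 and 2 set (illustrative)
uint32_t this_extra = ((extra32 << 2) & 0x04040404)     // extra bit 0 -> low nibbles
                    | ((extra32 << 4) & 0x40404040);    // extra bit 2 -> high nibbles
// Result here: 0x44444444. The second crumb position shifts by 1 and 3 instead,
// picking up extra bits 1 and 3 for its two blocks.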
@@ -965,12 +1011,31 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
int is = ib16%2;
const int * scales_l = (const int *)bq2->scales;

const int * all_values = (const int *)iq2k_table;

int scales = __vsub4(((scales_l[2*(ib32%4)+is] >> 4*(ib32/4)) & 0x0f0f0f0f), 0x08080808);
const int8_t * s8 = (const int8_t *)&scales;
int2 val1;

const int * q2 = (const int *)bq2->qs + 8*ib32 + 4*is;

#ifdef __CUDA_ARCH__

#pragma unroll
for (int i = 0; i < 4; ++i) {
uint32_t extra32 = uint32_t((bq2->extra[i+4*is] >> ib32) & 1) * 0x04040404;
extra32 |= (extra32 << 4);
uint32_t val1 = ((q2[i] >> 0) & 0x33333333) | extra32;
uint32_t val2 = ((q2[i] >> 2) & 0x33333333) | extra32;
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
int sumi = 0;
sumi = ggml_cuda_dp4a(v1.x, q8[0], ggml_cuda_dp4a(v2.x, q8[1], sumi));
sumi = ggml_cuda_dp4a(v1.y, q8[2], ggml_cuda_dp4a(v2.y, q8[3], sumi));
const float d = __half2float(bq2->d[i]) * d8;
result[i] += d * sumi * s8[i];
}

#else
const int * all_values = (const int *)iq2k_table;
int2 val1;
int aux32[2];
#pragma unroll
for (int i = 0; i < 4; ++i) {
@@ -989,6 +1054,7 @@ __device__ __forceinline__ void vec_dot_iq2_k_r4_q8_1(
const float d = __half2float(bq2->d[i]) * d8;
result[i] += d * sumi1 * s8[i];
}
#endif
}

#define VDR_IQ3_K_Q8_1_MMVQ 4
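In the row-interleaved _r4 variant above, each of the four rows has a single extra flag for the current block of 32, so the broadcast collapses to fanning one bit into bit 2 of all eight nibbles, and both crumb positions of a row share it. A sketch, with the flag value assumed:

uint32_t bit     = 1;                 // e.g. (bq2->extra[i+4*is] >> ib32) & 1
uint32_t extra32 = bit * 0x04040404;  // bit 2 of every byte
extra32 |= extra32 << 4;              // and bit 6: 0x44444444 covers both nibbles
// val1 and val2 OR in the same extra32, unlike the per-block variants above.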
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmq.cu
@@ -187,7 +187,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
break;
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_K_R4:
mmq_supported = ne11 < 2048;
mmq_supported = ne11 <= 3072;
break;
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
76 changes: 68 additions & 8 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -2566,11 +2566,45 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int * all_values = (const int *)iq2k_table;

const int kqsx = threadIdx.x%16;

#pragma unroll
#ifdef __CUDA_ARCH__
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
int i = i0 + 2*threadIdx.y + threadIdx.x/16;

if (need_check) {
i = min(i, i_max);
}

const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0;

uint16_t extra = bxi->extra >> 4*(kqsx/8);
int q2 = get_int_b2(bxi->qs, kqsx);

uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101;
uint32_t val1 = ((q2 >> 0) & 0x33333333) | ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040);
uint32_t val2 = ((q2 >> 2) & 0x33333333) | ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040);
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);

#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = v2.y;
#endif // INT8_MMA_AVAILABLE
}

#else // __CUDA_ARCH__

const int * all_values = (const int *)iq2k_table;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) {
int i = i0 + 2*threadIdx.y + threadIdx.x/16;

@@ -2595,6 +2629,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5));
#endif // INT8_MMA_AVAILABLE
}
#endif // __CUDA_ARCH__

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
@@ -2635,7 +2670,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
constexpr int qstep = 8;
const int kqsx = threadIdx.x % qstep;

#pragma unroll
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) {
int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep;

@@ -2645,13 +2680,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const block_iq2_k * bxi = (const block_iq2_k *)(x + i*stride) + kbx0;

auto all_values = (const int *)iq2k_table;

const float d = bxi->d;

uint16_t extra = bxi->extra >> (kqsx/4);

#pragma unroll
#ifdef __CUDA_ARCH__

uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 };
#pragma unroll
for (int l = 0; l < qstep/4; ++l) {
const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 2) & 0x44444444);
uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 0) & 0x44444444);
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y;
#endif // INT8_MMA_AVAILABLE
}

#else

auto all_values = (const int *)iq2k_table;

#pragma unroll
for (int l = 0; l < qstep/4; ++l) {

const int ql = get_int_b4(bxi->qs, kqsx + qstep*l);
@@ -2670,6 +2729,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

extra >>= 8;
}
#endif // __CUDA_ARCH__

#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8);
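One more note on the tile loads in this file: each thread expands 16 quants into four de-interleaved int words, and the kqsx%8 + 32*(kqsx/8) + {0, 8, 16, 24} addressing lands every block of 32 weights contiguously, in the order the Q8 tile is consumed. A standalone check of that mapping (hypothetical harness, not repo code):

// Prints the four destination int slots each loader thread writes.
#include <cstdio>

int main() {
    for (int kqsx = 0; kqsx < 16; ++kqsx) {
        const int base = kqsx % 8 + 32 * (kqsx / 8);
        printf("kqsx=%2d -> ints %3d %3d %3d %3d\n",
               kqsx, base + 0, base + 8, base + 16, base + 24);
    }
    return 0;  // threads 0..7 fill ints 0..31, threads 8..15 fill ints 32..63
}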
34 changes: 30 additions & 4 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_r4.cu
@@ -14,8 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

const int * all_values = (const int *)iq2k_table;

const int kqsx = threadIdx.x/4; // 0...7 -> block of 32

#pragma unroll
@@ -32,10 +30,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const float d = __half2float(bxi->d[ir]);

#pragma unroll
#ifdef __CUDA_ARCH__
#pragma unroll
for (int l = 0; l < 2; ++l) {

uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404;
extra = extra | (extra << 4);

const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);
uint32_t val1 = ((ql >> 0) & 0x33333333) | extra;
uint32_t val2 = ((ql >> 2) & 0x33333333) | extra;
int2 v1 = get_int_from_table_8(val1, iq2nl_values);
int2 v2 = get_int_from_table_8(val2, iq2nl_values);

#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y;
#endif // INT8_MMA_AVAILABLE
}

#else
#pragma unroll
for (int l = 0; l < 2; ++l) {

auto values_l = all_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);
auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8);

const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l);

Expand All @@ -51,6 +76,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l);
#endif // INT8_MMA_AVAILABLE
}
#endif // __CUDA_ARCH__

int is = 8*kqsx + ir;
float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8);