38 changes: 28 additions & 10 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -246,6 +246,25 @@ __device__ __forceinline__ void vec_dot_iq4_k_q8_1(
}

static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
#if defined(__CUDA_ARCH__)
uint32_t v1, v2, v3, v4, mask;
const uint32_t * values32 = (const uint32_t *)values;

// Selector for the final low/high choice: each nibble keeps the byte index 0-3 and gains +4
// when the corresponding index nibble of q4 has its MSB set (i.e. the index is 8-15).
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
// Perform lookups in the lower half of the table (indices 0-7).
v1 = __byte_perm(values32[0], values32[1], q4);
// Perform lookups in the upper half of the table (indices 8-15).
v2 = __byte_perm(values32[2], values32[3], q4);
// Select between the low and high results based on the MSB of each index nibble.
v3 = __byte_perm(v1, v2, mask);
// Same for the upper part of q4.
v1 = __byte_perm(values32[0], values32[1], q4 >> 16);
v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
v4 = __byte_perm(v1, v2, mask >> 16);

// Interleave the bytes of v3 and v4: .x gathers the low-nibble lookups, .y the high-nibble lookups.
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
#else
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
@@ -255,6 +274,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);

return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}

__device__ __forceinline__ void vec_dot_iq4_k_r4_q8_1(
@@ -389,19 +409,18 @@ __device__ __forceinline__ void vec_dot_iq4_ks_q8_1(

float scale = *(const float *)vbq;
const block_iq4_ks * bq4 = (const block_iq4_ks *)((const char *)vbq + sizeof(float)) + kbx;
const uint8_t * all_values = (const uint8_t *)iq4k_values;

// iqs is 0...28
const int ib32 = iqs/4; // Why iqs/4 ?
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
const float dl = scale * ((bq4->scales[ib32] & 254) - 127);
int v1, v2;
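// Bit 0 of the block scale selects which 16-entry half of the 32-entry iq4k_values table is used.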
auto values = iq4k_values + ((bq4->scales[ib32] & 1) << 4);
int sumi = 0;
for (int j = 0; j < 4; ++j) {
get_int_from_table_16_shift(q4[j], bq4->scales[ib32] & 1, all_values, v1, v2);
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
auto v = get_int_from_table_16(q4[j], values);
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
}
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}
@@ -560,7 +579,6 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1(

float scale = *(const float *)vbq;
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
const uint8_t * all_values = (const uint8_t *)iq4k_values;

// iqs is 0...28
const int ib32 = iqs/4; // Why iqs/4 ?
@@ -569,14 +587,14 @@ __device__ __forceinline__ void vec_dot_iq4_kss_q8_1(
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
const float dl = scale * ((ls & 254) - 127);
int v1, v2;
auto values = iq4k_values + ((ls & 1) << 4);
int sumi = 0;
for (int j = 0; j < 4; ++j) {
uint32_t aux32 = q4[j] & 0xfffefffe;
aux32 ^= (aux32 >> 1);
get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
auto v = get_int_from_table_16(aux32, values);
sumi = ggml_cuda_dp4a(v.x, q8[j+0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[j+4], sumi);
}
*result += dl * __low2float(bq8_1[ib32].ds) * sumi;
}
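The new device path above keeps the 16-entry table in four 32-bit registers and decodes all eight packed 4-bit indices of a 32-bit word with byte-permute instructions, instead of eight separate byte loads plus char4 packing. It works because CUDA's __byte_perm only consumes the low three bits of each selector nibble, so the same selectors can index both halves of the table; the precomputed mask then keeps, per nibble, the half indicated by the index's MSB, and the last two permutes split the result into the low-nibble (.x) and high-nibble (.y) lookups. Below is a minimal stand-alone sanity check of that equivalence against the plain byte-wise path, written as a sketch: the file name, the kernel and helper names, and the table contents are illustrative placeholders and are not part of this PR.

// table16_check.cu: illustrative stand-alone test (hypothetical file, not part of this PR).
#include <cstdio>
#include <cstdint>

// Arbitrary stand-in contents for a 16-entry signed lookup table such as kvalues_iq4nl.
__constant__ __align__(4) int8_t table[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 };

// Permute-based lookup, mirroring the new get_int_from_table_16 device path.
__device__ int2 lookup_perm(int q4, const int8_t * values) {
    const uint32_t * values32 = (const uint32_t *)values;
    uint32_t mask = 0x32103210 | ((q4 & 0x88888888) >> 1);
    uint32_t v1 = __byte_perm(values32[0], values32[1], q4);      // low table half, index nibbles 0-3
    uint32_t v2 = __byte_perm(values32[2], values32[3], q4);      // high table half, index nibbles 0-3
    uint32_t v3 = __byte_perm(v1, v2, mask);
    v1 = __byte_perm(values32[0], values32[1], q4 >> 16);         // index nibbles 4-7
    v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
    uint32_t v4 = __byte_perm(v1, v2, mask >> 16);
    return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
}

// Reference: look each of the eight 4-bit indices up one byte at a time.
__device__ int2 lookup_ref(int q4, const int8_t * values) {
    uint32_t lo = 0, hi = 0;
    for (int i = 0; i < 4; ++i) {
        lo |= (uint32_t)(uint8_t)values[(q4 >> (8*i + 0)) & 0x0F] << (8*i);
        hi |= (uint32_t)(uint8_t)values[(q4 >> (8*i + 4)) & 0x0F] << (8*i);
    }
    return make_int2((int)lo, (int)hi);
}

__global__ void check_kernel(const uint32_t * q, int n, int * fails) {
    int idx = blockIdx.x*blockDim.x + threadIdx.x;
    if (idx >= n) return;
    int2 a = lookup_perm((int)q[idx], table);
    int2 b = lookup_ref ((int)q[idx], table);
    if (a.x != b.x || a.y != b.y) atomicAdd(fails, 1);
}

int main() {
    const int n = 1 << 16;
    uint32_t * q = nullptr; int * fails = nullptr;
    cudaMallocManaged(&q, n*sizeof(uint32_t));
    cudaMallocManaged(&fails, sizeof(int));
    for (int i = 0; i < n; ++i) q[i] = 0x9E3779B9u*(uint32_t)(i + 1);  // arbitrary packed indices
    *fails = 0;
    check_kernel<<<(n + 255)/256, 256>>>(q, n, fails);
    cudaDeviceSynchronize();
    printf("mismatches: %d\n", *fails);
    cudaFree(q); cudaFree(fails);
    return 0;
}

Compiling this with nvcc and running it should report zero mismatches; if it does not, the most likely culprit is a misaligned or differently laid out values table, since the permute path reads the table as four 32-bit words.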
47 changes: 18 additions & 29 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -2509,9 +2509,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int kbx = 0; // threadIdx.x / QI4_XS
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS

uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
@@ -2523,15 +2520,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx;

const int q4 = get_int_b4(bxi->qs, kqsx);
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int2 v = get_int_from_table_16(q4);
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, iq4k_table);
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = int_from_table_x(a8+4, iq4k_table);
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
#endif // INT8_MMA_AVAILABLE
}

@@ -2842,9 +2838,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const int kqsx = threadIdx.x / 4;

uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2857,19 +2850,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
const int ls = (bxi->scales[kqsx] & 254) - 127;

auto values = iq4k_table + ((bxi->scales[kqsx] & 1) << 8);
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);

#pragma unroll
for (int j = 0; j < 4; ++j) {
const int q4 = get_int_b4(bxi->qs, 4*kqsx+j);
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int2 v = get_int_from_table_16(q4, values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE
@@ -2896,9 +2888,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const int kqsx = threadIdx.x/4;

uint32_t aux32[2];
const uint8_t * a8 = (const uint8_t *)aux32;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -2913,19 +2902,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0;

const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127;
auto values = iq4k_table + ((bxi->scales[4*kqsx+ir] & 1) << 8);
auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4);

#pragma unroll
for (int j = 0; j < 4; ++j) {
const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir);
aux32[0] = (q4 >> 0) & 0x0f0f0f0f;
aux32[1] = (q4 >> 4) & 0x0f0f0f0f;
const int2 v = get_int_from_table_16(q4, values);
const int k0 = 8*kqsx + 4*(j%2) + j/2;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = int_from_table_x(a8+4, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = int_from_table_x(a8+4, values);
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y;
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE
16 changes: 6 additions & 10 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu
@@ -14,9 +14,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

const int kqsx = threadIdx.x / 4;

uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
@@ -31,20 +28,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;

auto values = iq4k_table + ((ls & 1) << 8);
auto values = iq4k_values + ((ls & 1) << 4);

#pragma unroll
for (int j = 0; j < 4; ++j) {
uint32_t val = q4[j] & 0xfffefffe;
val = val ^ (val >> 1);
aux32[0] = (val >> 0) & 0x0f0f0f0f;
aux32[1] = (val >> 4) & 0x0f0f0f0f;
auto v = get_int_from_table_16(val, values);
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
#endif // INT8_MMA_AVAILABLE
}
#ifdef INT8_MMA_AVAILABLE
38 changes: 24 additions & 14 deletions ggml/src/ggml-cuda/vecdotq.cuh
@@ -1126,21 +1126,26 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}

static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);

const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);

return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}

static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) {
#if defined(__CUDA_ARCH__)
uint32_t v1, v2, v3, v4, mask;
const uint32_t * values32 = (const uint32_t *)values;

// Selector for the final low/high choice: each nibble keeps the byte index 0-3 and gains +4
// when the corresponding index nibble of q4 has its MSB set (i.e. the index is 8-15).
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
// Perform lookups in the lower half of the table (indices 0-7).
v1 = __byte_perm(values32[0], values32[1], q4);
// Perform lookups in the upper half of the table (indices 8-15).
v2 = __byte_perm(values32[2], values32[3], q4);
// Select between the low and high results based on the MSB of each index nibble.
v3 = __byte_perm(v1, v2, mask);
// Same for the upper part of q4.
v1 = __byte_perm(values32[0], values32[1], q4 >> 16);
v2 = __byte_perm(values32[2], values32[3], q4 >> 16);
v4 = __byte_perm(v1, v2, mask >> 16);

// Interleave the bytes of v3 and v4: .x gathers the low-nibble lookups, .y the high-nibble lookups.
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
#else
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]);
@@ -1150,6 +1155,11 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]);

return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}

static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
return get_int_from_table_16(q4, kvalues_iq4nl);
}

#define VDR_IQ4_NL_Q8_1_MMVQ 2