Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions ggml/src/ggml-cuda/vecdotq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}

// q4 contains 8 indices with 4 bit each.
// This function selects those bytes from table that are at those indices and returns them as int2.
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
#if defined(GGML_USE_HIP)
// Load the 16-byte table into four 32-bit unsigned integers.
const uint32_t *values = (const uint32_t *)table;

const uint32_t q_even = q4;
const uint32_t q_odd = (q4 >> 4);

// Perform lookups in the lower half of the table (indices 0-7).
uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);

// Perform lookups in the upper half of the table (indices 8-15).
uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);

// Select between the low and high results based on the MSB of each index nibble.
uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);

return make_int2(res_x, res_y);
#elif !defined(GGML_USE_MUSA)
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
const uint32_t * table32 = (const uint32_t *) table;

// __byte_perm selects bytes based on the lower 16 bits in its third argument.
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
// To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
uint32_t tmp[2];
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
#pragma unroll
for (uint32_t i = 0; i < 2; ++i) {
const uint32_t shift = 16 * i;

const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
}

// tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
// However, for the result we need ints with all even/odd 4 bit indices in q4.
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
#else
// Generic implementation.
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
Expand All @@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);

return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
#endif
}

// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
Expand Down
Loading