Skip to content

Commit 741934c

Browse files
revise documentation
1 parent f011aac commit 741934c

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
2828
return ((const int *) x)[i32]; // assume at least 4 byte alignment
2929
}
3030

31+
// q4 contains 8 indices with 4 bit each.
32+
// This function selects those bytes from table that are at those indices and returns them as int2.
33+
// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
3134
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
3235
#if defined(GGML_USE_HIP)
3336
// Load the 16-byte table into four 32-bit unsigned integers.
@@ -51,25 +54,32 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
5154
uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
5255

5356
return make_int2(res_x, res_y);
54-
#elif defined(__CUDA_ARCH__)
55-
uint32_t v1, v2, v3, v4, mask;
56-
const uint32_t *values = (const uint32_t *)table;
57+
#elif !defined(GGML_USE_MUSA)
58+
// CUDA does not have an instruction for selecting bytes with 4 bit indices.
59+
// However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
60+
const uint32_t * table32 = (const uint32_t *) table;
61+
62+
// __byte_perm selects bytes based on the lower 16 bits in its third argument.
63+
// Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
64+
// To handle the fourth bit, first call __byte_perm both for the low and the high 64 bits of table, using the low 3 bits.
65+
// Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
66+
uint32_t tmp[2];
67+
const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
68+
#pragma unroll
69+
for (uint32_t i = 0; i < 2; ++i) {
70+
const uint32_t shift = 16 * i;
5771

58-
mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
59-
// Perform lookups in the lower half of the table (indices 0-7).
60-
v1 = __byte_perm(values[0], values[1], q4);
61-
// Perform lookups in the upper half of the table (indices 8-15).
62-
v2 = __byte_perm(values[2], values[3], q4);
63-
// Select between the low and high results based on the MSB of each index nibble.
64-
v3 = __byte_perm(v1, v2, mask);
65-
// Same for the upper part of q4.
66-
v1 = __byte_perm(values[0], values[1], q4 >> 16);
67-
v2 = __byte_perm(values[2], values[3], q4 >> 16);
68-
v4 = __byte_perm(v1, v2, mask >> 16);
69-
70-
// Mix the results to get the final int2.
71-
return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
72+
const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
73+
const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
74+
tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
75+
}
76+
77+
// tmp contains the bytes from table in the same order as the 4 bit indices in q4.
78+
// However, for the result we need ints with all even/odd 4 bit indices in q4.
79+
// Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
80+
return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
7281
#else
82+
// Generic implementation.
7383
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
7484
const int8_t * q0_8 = (const int8_t *) &q0_32;
7585
const char4 val0_8 = make_char4(

0 commit comments

Comments
 (0)