@@ -28,6 +28,9 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
     return ((const int *) x)[i32]; // assume at least 4 byte alignment
 }
 
+// q4 contains 8 indices with 4 bits each.
+// This function selects those bytes from table that are at those indices and returns them as an int2.
+// The first int contains the bytes selected by the even-numbered indices in q4 (positions 0, 2, 4, 6), the second int those selected by the odd-numbered indices.
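+// For example, q4 = 0x7A3C5F19 holds the indices 0x9, 0x1, 0xF, 0x5, 0xC, 0x3, 0xA, 0x7 (lowest nibble first),
+// so .x packs table[0x9], table[0xF], table[0xC], table[0xA] and .y packs table[0x1], table[0x5], table[0x3], table[0x7].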
 static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
 #if defined(GGML_USE_HIP)
     // Load the 16-byte table into four 32-bit unsigned integers.
@@ -51,25 +54,32 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
     uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
 
     return make_int2(res_x, res_y);
-#elif defined(__CUDA_ARCH__)
-    uint32_t v1, v2, v3, v4, mask;
-    const uint32_t *values = (const uint32_t *)table;
+#elif !defined(GGML_USE_MUSA)
+    // CUDA does not have an instruction for selecting bytes with 4-bit indices.
+    // However, __byte_perm is an instruction that selects bytes with 3-bit indices and can be used instead.
+    const uint32_t * table32 = (const uint32_t *) table;
+
+    // __byte_perm selects bytes based on the lower 16 bits of its third argument.
+    // Therefore, do 2 iterations over the 32 bits in q4, with shifts of 0 and 16.
+    // To handle the fourth bit, first call __byte_perm for both the low and the high 64 bits of the table, using the low 3 bits of each index.
+    // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
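+    // For reference, __byte_perm(a, b, s) treats {b, a} as an 8-byte pool indexed 0-7 (a supplies bytes 0-3)
+    // and fills result byte i with the pool byte selected by the i-th nibble of s,
+    // e.g. __byte_perm(0x33221100, 0x77665544, 0x7530) == 0x77553300.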
+    uint32_t tmp[2];
+    const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
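+    // Each nibble of low_high_selection_indices is the base pattern 0x3210 plus 4 wherever the fourth bit
+    // of the corresponding index in q4 is set, e.g. an index 0xC at nibble position 0 yields the selector nibble 0x4.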
+#pragma unroll
+    for (uint32_t i = 0; i < 2; ++i) {
+        const uint32_t shift = 16 * i;
 
-    mask = (0x32103210 | ((q4 & 0x88888888) >> 1));
-    // Perform lookups in the lower half of the table (indices 0-7).
-    v1 = __byte_perm(values[0], values[1], q4);
-    // Perform lookups in the upper half of the table (indices 8-15).
-    v2 = __byte_perm(values[2], values[3], q4);
-    // Select between the low and high results based on the MSB of each index nibble.
-    v3 = __byte_perm(v1, v2, mask);
-    // Same for the upper part of q4.
-    v1 = __byte_perm(values[0], values[1], q4 >> 16);
-    v2 = __byte_perm(values[2], values[3], q4 >> 16);
-    v4 = __byte_perm(v1, v2, mask >> 16);
-
-    // Mix the results to get the final int2.
-    return make_int2(__byte_perm(v3, v4, 0x6420), __byte_perm(v3, v4, 0x7531));
+        const uint32_t low  = __byte_perm(table32[0], table32[1], q4 >> shift);
+        const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
+        tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
+    }
+
+    // tmp contains the bytes from table in the same order as the 4-bit indices in q4.
+    // However, the result needs one int with the bytes for all even indices and one with the bytes for all odd indices.
+    // Therefore, 2 more calls to __byte_perm are needed to put the bytes into the correct order.
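+    // The selector 0x6420 picks pool bytes 0, 2, 4, 6 of (tmp[1], tmp[0]) (the bytes for the even indices),
+    // the selector 0x7531 picks pool bytes 1, 3, 5, 7 (the bytes for the odd indices).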
+    return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
 #else
+    // Generic implementation.
     const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
     const int8_t * q0_8 = (const int8_t *) &q0_32;
     const char4 val0_8 = make_char4(