#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

#include "types.glsl"

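// Scale loaders for the quantized mat-vec shaders: get_dm() returns a
// block's scale factor(s), either a single scale d, a (d, m) or (d, dmin)
// pair (Q4_1/Q5_1/Q2_K), or a float decoded from MXFP4's E8M0 exponent byte.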
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
    return FLOAT_TYPE(data_a[ib].d);
}
#endif

#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif

#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif

#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
    const uint ib_k = ib / 8;
    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif

// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
// 2-byte loads for Q4_0 blocks (18 bytes)
i32vec2 repack(uint ib, uint iqs) {
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    return i32vec2( vui        & 0x0F0F0F0F,
                   (vui >>  4) & 0x0F0F0F0F);
}

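// Q4_0 quants are unsigned with an implicit offset of 8 (value = d * (q - 8)).
// With dsb = (d_b, d_b * sum(q_b)) from the Q8_1 block, the offset contributes
// -8 * dsb.y; sum_divisor presumably scales this down when the block's result
// is accumulated from multiple partial q_sum calls.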
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif

#if defined(DATA_A_Q4_1)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
    const uint32_t vui = data_a_packed32[ib].qs[iqs];
    return i32vec2( vui        & 0x0F0F0F0F,
                   (vui >>  4) & 0x0F0F0F0F);
}

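// Q4_1 is affine (value = d * q + m) with dma = (d, m); the min term
// contributes m * d_b * sum(q_b) = dma.y * dsb.y.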
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

#if defined(DATA_A_Q5_0)
// 2-byte loads for Q5_0 blocks (22 bytes)
i32vec2 repack(uint ib, uint iqs) {
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
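    // Multiplying a nibble of qh by 0x02040810 spreads its four bits apart;
    // masking with 0x10101010 keeps one bit per byte, landing the 5th bit of
    // each quant at bit 4 of its byte.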
    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

    return i32vec2(v0, v1);
}

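// Like Q4_0 but with 5-bit quants and an implicit offset of 16
// (value = d * (q - 16)).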
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif

#if defined(DATA_A_Q5_1)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
    const uint32_t vui = pack32(quants);
    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)

    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)

    return i32vec2(v0, v1);
}

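// Like Q4_1 (value = d * q + m), just with 5-bit quants.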
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                          data_a_packed16[ib].qs[iqs * 2 + 1]));
}

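// Q8_0 quants are signed 8-bit with no offset, so only the scales apply.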
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#endif

#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
    const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
                                          data_a[ib].qs[iqs * 4 + 1],
                                          data_a[ib].qs[iqs * 4 + 2],
                                          data_a[ib].qs[iqs * 4 + 3]));

    return i32vec2( quants        & 0x0F0F0F0F,
                   (quants >>  4) & 0x0F0F0F0F);
}

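// No offset term; da is the block scale decoded from the shared E8M0
// exponent byte (see get_dm above).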
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(da * dsb.x * float(q_sum));
}
#endif

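// Dot product of one block of A against the cached Q8_1 quants in
// cache_b_qs. For QUANT_R == 2 formats, one 32-bit load carries two nibble
// planes (low and high halves of the block), which repack() returns as a
// pair; otherwise two consecutive 32-bit values are loaded.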
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs, const int32_t sum_divisor) {
    int32_t q_sum = 0;
#if QUANT_R == 2
    const i32vec2 data_a_qs = repack(ib_a, iqs);
    q_sum += dotPacked4x8EXT(data_a_qs.x,
                             cache_b_qs[0]);
    q_sum += dotPacked4x8EXT(data_a_qs.y,
                             cache_b_qs[1]);
#else
    int32_t data_a_qs = repack(ib_a, iqs * 2);
    q_sum += dotPacked4x8EXT(data_a_qs,
                             cache_b_qs[0]);
    data_a_qs = repack(ib_a, iqs * 2 + 1);
    q_sum += dotPacked4x8EXT(data_a_qs,
                             cache_b_qs[1]);
#endif

    return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, sum_divisor);
}
#endif

#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
    const uint qs_shift = ((iqs_k % 32) / 8) * 2;

    return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}

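// Each Q2_K scale byte packs a 4-bit scale (low nibble) and a 4-bit min
// (high nibble) for one group of 16 quants; the caller splits the nibbles.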
uint8_t get_scale(uint ib, uint iqs) {
    const uint ib_k = ib / 8;
    const uint iqs_k = (ib % 8) * 8 + iqs;

    return data_a[ib_k].scales[iqs_k / 4];
}

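// dma = (d, dmin): values are d * scale * q - dmin * min per group, so the
// scale-weighted quant sum (sum_d) and min-weighted sum (sum_m) enter
// separately; sum_divisor is unused here.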
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#endif

#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
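// dma = (d, dmin) as for Q2_K; the caller is expected to fold the per-group
// min contribution into dsb.y, which the second term subtracts.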
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
    return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#endif