Skip to content

Commit 3aa74df

Browse files
committed
add q2_k mmvq
1 parent da201d6 commit 3aa74df

File tree

2 files changed: +23 additions, -23 deletions (46 lines changed)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
4040

4141
uint ibi = first_row*p.ncols;
4242
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
43-
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
43+
const uint a_block_idx = (ibi + col)/32 + a_offset;
4444
ibi += p.ncols;
4545

4646
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ i32vec2 repack(uint ib, uint iqs) {
4040
(vui >> 4) & 0x0F0F0F0F);
4141
}
4242

43-
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
44-
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
43+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
44+
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
4545
}
4646
#endif
4747

@@ -53,8 +53,8 @@ i32vec2 repack(uint ib, uint iqs) {
5353
(vui >> 4) & 0x0F0F0F0F);
5454
}
5555

56-
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
57-
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
56+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
57+
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
5858
}
5959
#endif
6060

@@ -74,8 +74,8 @@ i32vec2 repack(uint ib, uint iqs) {
7474
return i32vec2(v0, v1);
7575
}
7676

77-
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
78-
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
77+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
78+
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
7979
}
8080
#endif
8181

@@ -95,8 +95,8 @@ i32vec2 repack(uint ib, uint iqs) {
9595
return i32vec2(v0, v1);
9696
}
9797

98-
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
99-
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
98+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
99+
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
100100
}
101101
#endif
102102

@@ -107,8 +107,8 @@ int32_t repack(uint ib, uint iqs) {
107107
data_a_packed16[ib].qs[iqs * 2 + 1]));
108108
}
109109

110-
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
111-
return ACC_TYPE(float(q_sum) * da * dsb.x);
110+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
111+
return FLOAT_TYPE(float(q_sum) * da * dsb.x);
112112
}
113113
#endif
114114

@@ -127,8 +127,8 @@ i32vec2 repack(uint ib, uint iqs) {
127127
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
128128
}
129129

130-
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
131-
return ACC_TYPE(da * dsb.x * float(q_sum) * 0.5);
130+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
131+
return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
132132
}
133133
#endif
134134

@@ -157,14 +157,15 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
157157

158158
#if defined(DATA_A_Q2_K)
159159
// 4-byte loads for Q2_K blocks (84 bytes)
160-
int32_t repack(uint ib, uint iqs) {
160+
i32vec2 repack2(uint ib, uint iqs) {
161161
const uint ib_k = ib / 8;
162162
const uint iqs_k = (ib % 8) * 8 + iqs;
163163

164164
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
165165
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
166166

167-
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
167+
return i32vec2((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
168+
(data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303);
168169
}
169170

170171
uint8_t get_scale(uint ib, uint iqs) {
@@ -178,25 +179,24 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
178179
int32_t sum_d = 0;
179180
int32_t sum_m = 0;
180181

181-
const int32_t qs_a0 = repack(ib_a, iqs * 2);
182-
const int32_t qs_a1 = repack(ib_a, iqs * 2 + 1);
182+
const i32vec2 qs_a = repack2(ib_a, iqs * 2);
183183
const uint8_t scale = get_scale(ib_a, iqs * 2);
184+
const vec2 dm = vec2(get_dm(ib_a));
184185
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
185186

186-
sum_d += dotPacked4x8EXT(qs_a0, cache_b_qs[0]) * (scale & 0xF);
187+
sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
187188
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
188189

189-
sum_d += dotPacked4x8EXT(qs_a1, cache_b_qs[1]) * (scale & 0xF);
190+
sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
190191
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
191192

192-
const vec2 dm = get_dm(ib_a);
193-
return ACC_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m) / 4));
193+
return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
194194
}
195195
#endif
196196

197197
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
198198
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
199-
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
200-
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
199+
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
200+
return FLOAT_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
201201
}
202202
#endif

0 commit comments

Comments
 (0)