Skip to content

Commit aad4752

Browse files
committed
vulkan: split mul_mmq_funcs for mul_mat_vecq use
1 parent 31c511a commit aad4752

File tree

9 files changed

+240
-205
lines changed

9 files changed

+240
-205
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@
44

55
#include "types.glsl"
66

7-
#if defined(A_TYPE_PACKED16)
8-
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
9-
#endif
10-
#if defined(A_TYPE_PACKED32)
11-
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
12-
#endif
13-
147
#if defined(DATA_A_F32)
158
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
169
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ layout (push_constant) uniform parameter
1515
} p;
1616

1717
// Binding 0 is declared with the A_TYPE element type (data_a) and, when the
// corresponding macros are defined, again with the A_TYPE_PACKED16 and
// A_TYPE_PACKED32 element types (data_a_packed16 / data_a_packed32).
// All of these declarations use the same binding index.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
18+
#if defined(A_TYPE_PACKED16)
19+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
20+
#endif
21+
#if defined(A_TYPE_PACKED32)
22+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
23+
#endif
24+
1825
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
1926
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
2027

ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
1818
} p;
1919

2020
// Binding 0 is declared with the A_TYPE element type (data_a) and, when the
// corresponding macros are defined, again with the A_TYPE_PACKED16 and
// A_TYPE_PACKED32 element types (data_a_packed16 / data_a_packed32).
// All of these declarations use the same binding index.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
21+
#if defined(A_TYPE_PACKED16)
22+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
23+
#endif
24+
#if defined(A_TYPE_PACKED32)
25+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
26+
#endif
27+
2128
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
2229

2330
uint get_idx() {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
44

55
#include "mul_mat_vec_base.glsl"
6+
#include "dequant_funcs.glsl"
67

78
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
89

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313

1414
#include "types.glsl"
1515

16-
#ifndef MMQ
1716
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
18-
#else
19-
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
17+
#if defined(A_TYPE_PACKED16)
18+
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
19+
#endif
20+
#if defined(A_TYPE_PACKED32)
21+
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
2022
#endif
2123

2224
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
@@ -32,8 +34,6 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
3234
layout (binding = 3) readonly buffer IDS {int data_ids[];};
3335
#endif
3436

35-
#include "dequant_funcs.glsl"
36-
3737
layout (push_constant) uniform parameter
3838
{
3939
uint ncols;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
1212

1313
#define K_PER_ITER 8
1414

15-
#include "mul_mmq_funcs.glsl"
16-
1715
uint a_offset, b_offset, d_offset;
1816

1917
int32_t cache_b_qs[2];
2018
vec2 cache_b_ds;
2119

20+
#include "mul_mat_vecq_funcs.glsl"
21+
2222
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
2323
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
2424
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
@@ -43,27 +43,7 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
4343
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
4444
ibi += p.ncols;
4545

46-
int32_t q_sum = 0;
47-
#if QUANT_R == 2
48-
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
49-
q_sum += dotPacked4x8EXT(data_a_qs.x,
50-
cache_b_qs[0]);
51-
q_sum += dotPacked4x8EXT(data_a_qs.y,
52-
cache_b_qs[1]);
53-
#else
54-
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
55-
q_sum += dotPacked4x8EXT(data_a_qs,
56-
cache_b_qs[0]);
57-
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
58-
q_sum += dotPacked4x8EXT(data_a_qs,
59-
cache_b_qs[1]);
60-
#endif
61-
62-
#if QUANT_AUXF == 1
63-
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
64-
#else
65-
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
66-
#endif
46+
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx, 4);
6747
}
6848
}
6949
}
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
2+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
3+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
4+
5+
#include "types.glsl"
6+
7+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
8+
// Returns the `d` field of block `ib` in data_a, converted to FLOAT_TYPE.
// Compiled only for the DATA_A_* formats listed in the enclosing #if.
FLOAT_TYPE get_dm(uint ib) {
9+
return FLOAT_TYPE(data_a[ib].d);
10+
}
11+
#endif
12+
13+
#if defined(DATA_A_MXFP4)
14+
// MXFP4 variant: passes the `e` field of block `ib` in data_a through
// e8m0_to_fp32() and returns the result converted to FLOAT_TYPE.
FLOAT_TYPE get_dm(uint ib) {
15+
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
16+
}
17+
#endif
18+
19+
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
20+
// Q4_1 / Q5_1 variant: returns the `dm` field of block `ib`, read through the
// data_a_packed32 view and converted to FLOAT_TYPE_VEC2.
FLOAT_TYPE_VEC2 get_dm(uint ib) {
21+
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
22+
}
23+
#endif
24+
25+
#if defined(DATA_A_Q2_K)
26+
// Q2_K variant: returns the `dm` field of block `ib / 8`, read through the
// data_a_packed32 view and converted to FLOAT_TYPE_VEC2.
// NOTE(review): the division by 8 presumably maps a 32-value sub-block index
// onto its enclosing Q2_K block -- confirm against the caller.
FLOAT_TYPE_VEC2 get_dm(uint ib) {
27+
const uint ib_k = ib / 8;
28+
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
29+
}
30+
#endif
31+
32+
// Each iqs value maps to a 32-bit integer
33+
#if defined(DATA_A_Q4_0)
34+
// 2-byte loads for Q4_0 blocks (18 bytes)
35+
// Q4_0 variant. Reads qs entries `iqs * 2` and `iqs * 2 + 1` of block `ib`
// through the data_a_packed16 view and packs them into one 32-bit word.
// Returns:
//   .x - the low 4 bits of each byte of that word
//   .y - the high 4 bits of each byte, shifted down into the low nibble
// Each component therefore holds four byte lanes with values in 0..15.
i32vec2 repack(uint ib, uint iqs) {
36+
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
37+
data_a_packed16[ib].qs[iqs * 2 + 1]);
38+
const uint32_t vui = pack32(quants);
39+
return i32vec2( vui & 0x0F0F0F0F,
40+
(vui >> 4) & 0x0F0F0F0F);
41+
}
42+
43+
// Q4_0 variant. Applies the scale factors to an integer dot product.
//   q_sum       - integer dot-product sum
//   da          - scale of the A block (mmvq_dot_product passes get_dm())
//   dsb         - .x multiplies q_sum, .y multiplies the constant offset term
//   sum_divisor - divides the constant 8; both operands are integers, so this
//                 is an integer division and truncates if 8 is not a multiple
//                 of sum_divisor
// Returns da * (q_sum * dsb.x - (8 / sum_divisor) * dsb.y) as ACC_TYPE.
// NOTE(review): the constant 8 presumably removes the bias of the unsigned
// 0..15 nibbles produced by repack() -- confirm against the Q4_0 format.
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
44+
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
45+
}
46+
#endif
47+
48+
#if defined(DATA_A_Q4_1)
49+
// 4-byte loads for Q4_1 blocks (20 bytes)
50+
// Q4_1 variant. Reads qs entry `iqs` of block `ib` through the
// data_a_packed32 view.
// Returns:
//   .x - the low 4 bits of each byte of that word
//   .y - the high 4 bits of each byte, shifted down into the low nibble
i32vec2 repack(uint ib, uint iqs) {
51+
const uint32_t vui = data_a_packed32[ib].qs[iqs];
52+
return i32vec2( vui & 0x0F0F0F0F,
53+
(vui >> 4) & 0x0F0F0F0F);
54+
}
55+
56+
// Q4_1 variant. Applies the scale factors to an integer dot product.
//   q_sum       - integer dot-product sum
//   dma         - pair from get_dm(); .x scales q_sum, .y scales the additive term
//   dsb         - .x multiplies q_sum, .y multiplies dma.y
//   sum_divisor - divides the dma.y * dsb.y product
// Returns q_sum * dma.x * dsb.x + dma.y * dsb.y / sum_divisor as ACC_TYPE.
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
57+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
58+
}
59+
#endif
60+
61+
#if defined(DATA_A_Q5_0)
62+
// 2-byte loads for Q5_0 blocks (22 bytes)
63+
// Q5_0 variant. Builds two words of four 5-bit values each.
//   vui - qs entries `iqs * 2` and `iqs * 2 + 1` of block `ib` (16-bit view),
//         packed into 32 bits
//   qh  - the two 16-bit qh entries joined into 32 bits (qh[1] in the upper
//         half), then shifted right by 4 * iqs
//   v0  - low nibbles of vui, with bits 0..3 of qh OR-ed in as bit 4 of
//         bytes 0..3 respectively
//   v1  - high nibbles of vui, with bits 16..19 of qh OR-ed in the same way
// Multiplying the 4-bit group by 0x02040810 and masking with 0x10101010 moves
// bit k of the group to bit 4 of byte k (see the trailing comments).
// `&` binds tighter than `|`, so only the multiplied term is masked.
i32vec2 repack(uint ib, uint iqs) {
64+
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
65+
data_a_packed16[ib].qs[iqs * 2 + 1]);
66+
const uint32_t vui = pack32(quants);
67+
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
68+
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
69+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
70+
71+
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
72+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
73+
74+
return i32vec2(v0, v1);
75+
}
76+
77+
// Q5_0 variant. Applies the scale factors to an integer dot product.
//   q_sum       - integer dot-product sum
//   da          - scale of the A block (mmvq_dot_product passes get_dm())
//   dsb         - .x multiplies q_sum, .y multiplies the constant offset term
//   sum_divisor - divides the constant 16; both operands are integers, so this
//                 is an integer division and truncates if 16 is not a multiple
//                 of sum_divisor
// Returns da * (q_sum * dsb.x - (16 / sum_divisor) * dsb.y) as ACC_TYPE.
// NOTE(review): the constant 16 presumably removes the bias of the unsigned
// 0..31 values produced by repack() -- confirm against the Q5_0 format.
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
78+
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
79+
}
80+
#endif
81+
82+
#if defined(DATA_A_Q5_1)
83+
// Q5_1 blocks (24 bytes): 2-byte loads for qs, 4-byte load for qh
84+
// Q5_1 variant. Builds two words of four 5-bit values each.
//   vui - qs entries `iqs * 2` and `iqs * 2 + 1` of block `ib` (16-bit view),
//         packed into 32 bits
//   qh  - the qh field read through the data_a_packed32 view, shifted right
//         by 4 * iqs
//   v0  - low nibbles of vui, with bits 0..3 of qh OR-ed in as bit 4 of
//         bytes 0..3 respectively
//   v1  - high nibbles of vui, with bits 16..19 of qh OR-ed in the same way
// Multiplying the 4-bit group by 0x02040810 and masking with 0x10101010 moves
// bit k of the group to bit 4 of byte k (see the trailing comments).
// `&` binds tighter than `|`, so only the multiplied term is masked.
i32vec2 repack(uint ib, uint iqs) {
85+
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
86+
data_a_packed16[ib].qs[iqs * 2 + 1]);
87+
const uint32_t vui = pack32(quants);
88+
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
89+
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
90+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
91+
92+
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
93+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
94+
95+
return i32vec2(v0, v1);
96+
}
97+
98+
// Q5_1 variant. Applies the scale factors to an integer dot product.
//   q_sum       - integer dot-product sum
//   dma         - pair from get_dm(); .x scales q_sum, .y scales the additive term
//   dsb         - .x multiplies q_sum, .y multiplies dma.y
//   sum_divisor - divides the dma.y * dsb.y product
// Returns q_sum * dma.x * dsb.x + dma.y * dsb.y / sum_divisor as ACC_TYPE.
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
99+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
100+
}
101+
#endif
102+
103+
#if defined(DATA_A_Q8_0)
104+
// 2-byte loads for Q8_0 blocks (34 bytes)
105+
// Q8_0 variant. Reads qs entries `iqs * 2` and `iqs * 2 + 1` of block `ib`
// through the data_a_packed16 view and returns them packed into a single
// 32-bit integer. Unlike the 4/5-bit variants, no masking or shifting is done.
int32_t repack(uint ib, uint iqs) {
106+
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
107+
data_a_packed16[ib].qs[iqs * 2 + 1]));
108+
}
109+
110+
// Q8_0 variant. Returns q_sum * da * dsb.x as ACC_TYPE.
// dsb.y and sum_divisor are accepted for signature parity with the other
// variants but are not used here.
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
111+
return ACC_TYPE(float(q_sum) * da * dsb.x);
112+
}
113+
#endif
114+
115+
#if defined(DATA_A_MXFP4)
116+
// 1-byte loads for mxfp4 blocks (17 bytes)
117+
// MXFP4 variant. Reads four consecutive qs entries (`iqs * 4` .. `iqs * 4 + 3`)
// of block `ib` from data_a and packs them, as a u8vec4, into one 32-bit word.
// Returns:
//   .x - the low 4 bits of each byte of that word
//   .y - the high 4 bits of each byte, shifted down into the low nibble
i32vec2 repack(uint ib, uint iqs) {
118+
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
119+
data_a[ib].qs[iqs * 4 + 1],
120+
data_a[ib].qs[iqs * 4 + 2],
121+
data_a[ib].qs[iqs * 4 + 3]));
122+
123+
return i32vec2( quants & 0x0F0F0F0F,
124+
(quants >> 4) & 0x0F0F0F0F);
125+
}
126+
127+
// MXFP4 variant. Returns da * dsb.x * q_sum as ACC_TYPE.
// dsb.y and sum_divisor are accepted for signature parity with the other
// variants but are not used here.
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
128+
return ACC_TYPE(da * dsb.x * float(q_sum));
129+
}
130+
#endif
131+
132+
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
133+
// Computes the scaled dot product of A block `ib_a` with the cached B values.
//   ib_a        - index of the A block, forwarded to repack() and get_dm()
//   iqs         - quant index forwarded to repack() (doubled when QUANT_R != 2)
//   sum_divisor - forwarded unchanged to mul_q8_1()
// Reads the globals cache_b_qs[0], cache_b_qs[1] and cache_b_ds, which this
// file does not declare; they must already be in scope where it is included.
// Returns the value of mul_q8_1() for the accumulated integer sum.
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs, const int32_t sum_divisor) {
134+
int32_t q_sum = 0;
135+
#if QUANT_R == 2
136+
// repack() yields two words per call: .x pairs with cache_b_qs[0], .y with cache_b_qs[1].
const i32vec2 data_a_qs = repack(ib_a, iqs);
137+
q_sum += dotPacked4x8EXT(data_a_qs.x,
138+
cache_b_qs[0]);
139+
q_sum += dotPacked4x8EXT(data_a_qs.y,
140+
cache_b_qs[1]);
141+
#else
142+
// repack() yields one word per call, so it is called for two consecutive indices.
int32_t data_a_qs = repack(ib_a, iqs * 2);
143+
q_sum += dotPacked4x8EXT(data_a_qs,
144+
cache_b_qs[0]);
145+
data_a_qs = repack(ib_a, iqs * 2 + 1);
146+
q_sum += dotPacked4x8EXT(data_a_qs,
147+
cache_b_qs[1]);
148+
#endif
149+
150+
// Apply the A-side scale(s) from get_dm() and the cached B-side pair.
return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, sum_divisor);
151+
}
152+
#endif
153+
154+
#if defined(DATA_A_Q2_K)
155+
// 4-byte loads for Q2_K blocks (84 bytes)
156+
// Q2_K variant. Extracts four 2-bit values, one per byte lane.
//   ib_k     - ib / 8, the block index used for the load
//   iqs_k    - (ib % 8) * 8 + iqs, the quant index within that block
//   qs_idx   - (iqs_k / 32) * 8 + (iqs_k % 8), the 32-bit qs word to read
//   qs_shift - ((iqs_k % 32) / 8) * 2, i.e. 0, 2, 4 or 6
// Returns the selected qs word shifted right by qs_shift and masked with
// 0x03030303, leaving the 2-bit field of every byte.
// NOTE(review): `ib` presumably counts 32-value sub-blocks, eight per Q2_K
// block -- confirm against the caller.
int32_t repack(uint ib, uint iqs) {
157+
const uint ib_k = ib / 8;
158+
const uint iqs_k = (ib % 8) * 8 + iqs;
159+
160+
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
161+
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
162+
163+
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
164+
}
165+
166+
// Q2_K only. Returns the raw scales entry covering quant index `iqs` of `ib`.
// Uses the same ib_k / iqs_k mapping as the Q2_K repack(): the block is
// ib / 8, the index within it is (ib % 8) * 8 + iqs, and one scales entry
// covers four consecutive iqs_k values (iqs_k / 4).
uint8_t get_scale(uint ib, uint iqs) {
167+
const uint ib_k = ib / 8;
168+
const uint iqs_k = (ib % 8) * 8 + iqs;
169+
170+
return data_a[ib_k].scales[iqs_k / 4];
171+
}
172+
173+
// Q2_K variant. Takes two integer sums instead of one.
//   sum_d - integer sum scaled by dma.x
//   sum_m - integer sum scaled by dma.y and subtracted
//   dma   - pair from get_dm()
//   dsb   - only .x is used, as a common factor
// Returns dsb.x * (dma.x * sum_d - dma.y * sum_m) as ACC_TYPE.
// dsb.y and sum_divisor are not used by this variant.
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
174+
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
175+
}
176+
#endif
177+
178+
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
179+
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
180+
// Q4_K / Q5_K variant. Applies the scale factors to an integer dot product.
//   q_sum - integer dot-product sum
//   dma   - .x scales q_sum, .y scales the subtracted term
//   dsb   - .x multiplies q_sum, .y multiplies dma.y
// Returns dsb.x * dma.x * q_sum - dma.y * dsb.y as ACC_TYPE.
// sum_divisor is not used by this variant.
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
181+
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
182+
}
183+
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;
7878

7979
#define BK 32
8080

81-
#define MMQ_SHMEM
82-
8381
#include "mul_mmq_shmem_types.glsl"
8482

8583
#ifdef MUL_MAT_ID

0 commit comments

Comments
 (0)