Skip to content

Commit 40d75d9

Browse files
committed
Add mxfp4 mmq, enable MMQ MUL_MAT_ID
1 parent 84cb48c commit 40d75d9

File tree

12 files changed

+290
-132
lines changed

12 files changed

+290
-132
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 125 additions & 26 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
437437
#if defined(DATA_A_MXFP4)
438438
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
439439
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
440-
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
440+
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
441441
}
442442
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
443443
vec2 v0 = dequantize(ib, iqs, a_offset);

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
680680
uint32_t qs = bl.block.qs[iqs];
681681
qs >>= shift;
682682
qs &= 0xF;
683-
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
683+
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
684684
return ret;
685685
}
686686
#endif

ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ void main() {
2626
const float d = e8m0_to_fp32(data_a[ib].e);
2727

2828
[[unroll]] for (uint l = 0; l < 8; ++l) {
29-
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
30-
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
29+
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
30+
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
3131
}
3232
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 1 addition & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
120120

121121
#define NUM_WARPS (BLOCK_SIZE / WARP)
122122

123-
#ifdef MUL_MAT_ID
124-
shared u16vec2 row_ids[BN];
125-
uint _ne1;
126-
127-
#ifdef MUL_MAT_ID_USE_SUBGROUPS
128-
shared uvec4 ballots_sh[NUM_WARPS];
129-
130-
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
131-
_ne1 = 0;
132-
uint num_elements = p.nei1 * p.nei0;
133-
uint nei0shift = findLSB(p.nei0);
134-
135-
uint ids[16];
136-
uint iter = 0;
137-
138-
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
139-
// prefetch up to 16 elements
140-
if (iter == 0) {
141-
[[unroll]] for (uint k = 0; k < 16; ++k) {
142-
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
143-
bool in_range = i < num_elements;
144-
uint ii1;
145-
if (nei0_is_pow2) {
146-
ii1 = i >> nei0shift;
147-
} else {
148-
ii1 = i / p.nei0;
149-
}
150-
uint ii0 = i - ii1 * p.nei0;
151-
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
152-
}
153-
}
154-
uint i = j + gl_LocalInvocationIndex;
155-
bool in_range = i < num_elements;
156-
uint ii1;
157-
if (nei0_is_pow2) {
158-
ii1 = i >> nei0shift;
159-
} else {
160-
ii1 = i / p.nei0;
161-
}
162-
uint ii0 = i - ii1 * p.nei0;
163-
uint id = ids[iter++];
164-
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
165-
166-
ballots_sh[gl_SubgroupID] = ballot;
167-
barrier();
168-
169-
uint subgroup_base = 0;
170-
uint total = 0;
171-
for (uint k = 0; k < gl_NumSubgroups; ++k) {
172-
if (k == gl_SubgroupID) {
173-
subgroup_base = total;
174-
}
175-
total += subgroupBallotBitCount(ballots_sh[k]);
176-
}
177-
barrier();
178-
179-
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
180-
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
181-
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
182-
}
183-
_ne1 += total;
184-
iter &= 15;
185-
if (_ne1 >= (ic + 1) * BN) {
186-
break;
187-
}
188-
}
189-
barrier();
190-
}
191-
#endif // MUL_MAT_ID_USE_SUBGROUPS
192-
#endif // MUL_MAT_ID
193-
194123
#ifdef COOPMAT
195124
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
196125
#endif
197126

127+
#include "mul_mm_id_funcs.glsl"
198128
#include "mul_mm_funcs.glsl"
199129

200130
void main() {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
468468
const uint ib = idx / 8;
469469
const uint iqs = (idx & 0x07) * 2;
470470

471-
const float d = e8m0_to_fp32(data_a[ib].e);
471+
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
472472
const uint vui = uint(data_a[ib].qs[iqs]);
473473
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
474474

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#ifdef MUL_MAT_ID
2+
shared u16vec2 row_ids[BN];
3+
uint _ne1;
4+
5+
#ifdef MUL_MAT_ID_USE_SUBGROUPS
6+
shared uvec4 ballots_sh[NUM_WARPS];
7+
8+
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
9+
_ne1 = 0;
10+
uint num_elements = p.nei1 * p.nei0;
11+
uint nei0shift = findLSB(p.nei0);
12+
13+
uint ids[16];
14+
uint iter = 0;
15+
16+
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
17+
// prefetch up to 16 elements
18+
if (iter == 0) {
19+
[[unroll]] for (uint k = 0; k < 16; ++k) {
20+
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
21+
bool in_range = i < num_elements;
22+
uint ii1;
23+
if (nei0_is_pow2) {
24+
ii1 = i >> nei0shift;
25+
} else {
26+
ii1 = i / p.nei0;
27+
}
28+
uint ii0 = i - ii1 * p.nei0;
29+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
30+
}
31+
}
32+
uint i = j + gl_LocalInvocationIndex;
33+
bool in_range = i < num_elements;
34+
uint ii1;
35+
if (nei0_is_pow2) {
36+
ii1 = i >> nei0shift;
37+
} else {
38+
ii1 = i / p.nei0;
39+
}
40+
uint ii0 = i - ii1 * p.nei0;
41+
uint id = ids[iter++];
42+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
43+
44+
ballots_sh[gl_SubgroupID] = ballot;
45+
barrier();
46+
47+
uint subgroup_base = 0;
48+
uint total = 0;
49+
for (uint k = 0; k < gl_NumSubgroups; ++k) {
50+
if (k == gl_SubgroupID) {
51+
subgroup_base = total;
52+
}
53+
total += subgroupBallotBitCount(ballots_sh[k]);
54+
}
55+
barrier();
56+
57+
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
58+
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
59+
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
60+
}
61+
_ne1 += total;
62+
iter &= 15;
63+
if (_ne1 >= (ic + 1) * BN) {
64+
break;
65+
}
66+
}
67+
barrier();
68+
}
69+
#endif // MUL_MAT_ID_USE_SUBGROUPS
70+
#endif // MUL_MAT_ID

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
1111
#endif
1212

13+
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
14+
#extension GL_KHR_shader_subgroup_basic : enable
15+
#extension GL_KHR_shader_subgroup_ballot : enable
16+
#endif
17+
1318
#ifdef MUL_MAT_ID
1419
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
1520
#endif
@@ -91,13 +96,9 @@ block_b_cache cache_b;
9196
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
9297
#define LOAD_VEC_B 16
9398

94-
// TODO: Recheck if this can work with mul_mat_id
95-
#ifdef MUL_MAT_ID
96-
shared u16vec2 row_ids[4096];
97-
#endif // MUL_MAT_ID
98-
9999
#define NUM_WARPS (BLOCK_SIZE / WARP)
100100

101+
#include "mul_mm_id_funcs.glsl"
101102
#include "mul_mmq_funcs.glsl"
102103

103104
void main() {
@@ -146,17 +147,27 @@ void main() {
146147
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
147148

148149
#ifdef MUL_MAT_ID
149-
uint _ne1 = 0;
150-
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
151-
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
150+
#ifdef MUL_MAT_ID_USE_SUBGROUPS
151+
if (bitCount(p.nei0) == 1) {
152+
load_row_ids(expert_idx, true, ic);
153+
} else {
154+
load_row_ids(expert_idx, false, ic);
155+
}
156+
#else
157+
_ne1 = 0;
158+
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
159+
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
152160
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
153-
row_ids[_ne1] = u16vec2(ii0, ii1);
161+
if (_ne1 >= ic * BN) {
162+
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
163+
}
154164
_ne1++;
155165
}
156166
}
157167
}
158168

159169
barrier();
170+
#endif
160171

161172
// Workgroup has no work
162173
if (ic * BN >= _ne1) return;
@@ -203,15 +214,12 @@ void main() {
203214
const uint buf_ib = loadc_b + l;
204215

205216
#ifdef MUL_MAT_ID
206-
const u16vec2 row_idx = row_ids[ic * BN + buf_ib];
207-
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
208-
const uint ib = idx / 8;
209-
const uint iqs = idx & 0x7;
217+
const u16vec2 row_idx = row_ids[buf_ib];
218+
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
210219
#else
211220
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
212-
213-
const uint iqs = loadr_b;
214221
#endif
222+
const uint iqs = loadr_b;
215223

216224
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
217225
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
5959
}
6060
}
6161

62-
void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) {
63-
}
64-
6562
ACC_TYPE mmq_dot_product(const uint ib_a) {
6663
int32_t q_sum = 0;
6764
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
@@ -205,6 +202,61 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
205202
#endif // MMQ_SHMEM
206203
#endif
207204

205+
#if defined(DATA_A_MXFP4)
206+
// 1-byte loads for mxfp4 blocks (17 bytes)
207+
i32vec2 repack(uint ib, uint iqs) {
208+
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
209+
data_a[ib].qs[iqs * 4 + 1],
210+
data_a[ib].qs[iqs * 4 + 2],
211+
data_a[ib].qs[iqs * 4 + 3]));
212+
213+
return i32vec2( quants & 0x0F0F0F0F,
214+
(quants >> 4) & 0x0F0F0F0F);
215+
}
216+
217+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
218+
return ACC_TYPE(da * dsb.x * float(q_sum));
219+
}
220+
221+
#ifdef MMQ_SHMEM
222+
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
223+
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
224+
data_a[ib].qs[iqs * 4 + 1],
225+
data_a[ib].qs[iqs * 4 + 2],
226+
data_a[ib].qs[iqs * 4 + 3]));
227+
228+
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
229+
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
230+
231+
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
232+
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
233+
234+
if (iqs == 0) {
235+
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
236+
}
237+
}
238+
239+
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
240+
cache_a[reg_ib].d = buf_a[buf_ib].d;
241+
242+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
243+
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
244+
}
245+
}
246+
247+
ACC_TYPE mmq_dot_product(const uint ib_a) {
248+
int32_t q_sum = 0;
249+
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
250+
const int32_t qs_a = cache_a[ib_a].qs[iqs];
251+
252+
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
253+
}
254+
255+
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
256+
}
257+
#endif // MMQ_SHMEM
258+
#endif
259+
208260
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
209261
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
210262
#if defined(DATA_A_Q2_K)

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ struct block_a_cache {
3232
int32_t qs[32/4];
3333
FLOAT_TYPE dm;
3434
};
35+
#elif defined(DATA_A_MXFP4)
36+
#define QUANT_R_MMQ 2
37+
struct block_a_cache {
38+
int32_t qs[8];
39+
FLOAT_TYPE d;
40+
};
3541
#elif defined(DATA_A_Q2_K)
3642
#define QUANT_R_MMQ 4
3743
struct block_a_cache {

0 commit comments

Comments
 (0)