Commit 0775df7

Reduce mmq register use
1 parent c4711d8 commit 0775df7

File tree

2 files changed: +61 -165 lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 26 additions & 143 deletions
@@ -10,12 +10,6 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif
 
-#ifdef COOPMAT
-#extension GL_KHR_cooperative_matrix : enable
-#extension GL_KHR_memory_scope_semantics : enable
-#extension GL_KHR_shader_subgroup_basic : enable
-#endif
-
 #ifdef MUL_MAT_ID
 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 #endif
@@ -79,10 +73,6 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #define BK 32
 
-#ifdef COOPMAT
-#define SHMEM_STRIDE (BK / 4 + 4)
-#endif
-
 #define MMQ_SHMEM
 
 #include "mul_mmq_shmem_types.glsl"
@@ -92,7 +82,7 @@ shared block_a_cache buf_a[BM];
 shared block_b_cache buf_b[BN];
 // Register cache
 block_a_cache cache_a[WMITER * TM];
-block_b_cache cache_b[TN];
+block_b_cache cache_b;
 
 #define LOAD_VEC_A (4 * QUANT_R_MMQ)
 #define LOAD_VEC_B 16
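This hunk is the core of the register saving: instead of holding all TN columns of B in the per-thread register cache, the shader keeps a single block_b_cache and refills it from shared memory once per output column. A rough C model of the footprint difference (the struct layout and TN value are illustrative assumptions, not the shader's exact types):

#include <stdint.h>
#include <stdio.h>

/* Rough C model of the shader's block_b_cache: BK/4 = 8 packed int32 quants
   plus a two-component scale (exact field types are a simplifying assumption). */
typedef struct { int32_t qs[8]; float ds[2]; } block_b_cache;

enum { TN = 4 };   /* illustrative tile width, not the shader's actual value */

int main(void) {
    /* Before: TN cached B blocks per thread; after: one, refilled per column. */
    printf("register bytes per thread, before: %zu, after: %zu\n",
           TN * sizeof(block_b_cache), sizeof(block_b_cache));
    return 0;
}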
@@ -104,10 +94,6 @@ shared u16vec2 row_ids[4096];
 
 #define NUM_WARPS (BLOCK_SIZE / WARP)
 
-#ifdef COOPMAT
-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
-#endif
-
 #include "mul_mmq_funcs.glsl"
 
 void main() {
@@ -137,26 +123,12 @@ void main() {
     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
     const uint WSUBN = WN / WNITER;
-
-#ifdef COOPMAT
-    const uint warp_i = gl_SubgroupID;
-
-    const uint tiw = gl_SubgroupInvocationID;
-
-    const uint cms_per_row = WM / TM;
-    const uint cms_per_col = WN / TN;
-
-    const uint storestride = WARP / TM;
-    const uint store_r = tiw % TM;
-    const uint store_c = tiw / TM;
-#else
     const uint warp_i = gl_LocalInvocationID.x / WARP;
 
     const uint tiw = gl_LocalInvocationID.x % WARP;
 
     const uint tiwr = tiw % (WSUBM / TM);
     const uint tiwc = tiw / (WSUBM / TM);
-#endif
 
     const uint warp_r = warp_i % (BM / WM);
     const uint warp_c = warp_i / (BM / WM);
@@ -207,26 +179,11 @@ void main() {
     uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
 #endif
 
-#ifdef COOPMAT
-    coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
-    coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
-    coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
-
-    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
-
-    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
+    ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN / 2];
 
-    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
-        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+        sums[i] = ACC_TYPE_VEC2(0.0f);
     }
-#else
-
-    ACC_TYPE sums[WMITER * TM * WNITER * TN];
-
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = ACC_TYPE(0.0f);
-    }
-#endif
 
     for (uint block = start_k; block < end_k; block += BK) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
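The accumulators change in the same spirit: pairs of row-adjacent scalar sums are packed into one two-component value, halving the number of named accumulator slots. A small C sketch of the two layouts, with acc_vec2 standing in for ACC_TYPE_VEC2 and assumed tile sizes:

#include <stdio.h>

typedef struct { float x, y; } acc_vec2;   /* stands in for ACC_TYPE_VEC2 */

enum { WMITER = 2, TM = 4, WNITER = 2, TN = 2 };   /* illustrative tile sizes */

int main(void) {
    /* Old: one scalar accumulator per output value. */
    float    sums_old[WMITER * TM * WNITER * TN];
    /* New: two row-adjacent accumulators share one vec2 slot, which can
       encourage the compiler to keep pairs in vector registers. */
    acc_vec2 sums_new[WMITER * TM * WNITER * TN / 2];

    printf("scalar slots: %zu, packed slots: %zu (same %zu floats)\n",
           sizeof(sums_old) / sizeof(sums_old[0]),
           sizeof(sums_new) / sizeof(sums_new[0]),
           sizeof(sums_new) / sizeof(float));
    return 0;
}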
@@ -267,38 +224,6 @@ void main() {
             pos_a_ib += 1;
             pos_b_ib += 1;
 
-#ifdef COOPMAT
-        [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-            const uint ib_a = warp_r * WM + cm_row * TM;
-            // Load from shared into cache
-            coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
-
-            // TODO: only cache values that are actually needed
-            [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
-                cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
-            }
-
-            [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-                const uint ib_b = warp_c * WN + cm_col * TN;
-                coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
-
-                // TODO: only cache values that are actually needed
-                [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
-                    cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
-                }
-
-                cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
-                cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
-                }
-
-                coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-                sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
-            }
-        }
-#else
         // Load from shared into cache
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
             [[unroll]] for (uint cr = 0; cr < TM; cr++) {
@@ -312,24 +237,22 @@ void main() {
         [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
             [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                 const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
-                cache_b[cc].ds = buf_b[ib].ds;
+                cache_b.ds = buf_b[ib].ds;
                 [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
-                    cache_b[cc].qs[iqs] = buf_b[ib].qs[iqs];
+                    cache_b.qs[iqs] = buf_b[ib].qs[iqs];
                 }
-            }
 
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                        const uint cache_a_idx = wsir * TM + cr;
-                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                    [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                        const uint cache_a_idx = wsir * TM + cr * 2;
+                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
 
-                        sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
+                        sums[sums_idx].x += mmq_dot_product(cache_a_idx);
+                        sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
                     }
                 }
             }
         }
-#endif
 
         barrier();
     }
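Each inner iteration now produces two adjacent output rows per packed accumulator (the .x and .y halves), and sums_idx addresses the half-sized array. A quick C check, under assumed tile sizes, that the new index arithmetic touches every packed slot exactly once:

#include <stdio.h>
#include <string.h>

enum { WMITER = 2, TM = 4, WNITER = 2, TN = 2 };   /* assumed even tile sizes */

int main(void) {
    int hits[WMITER * TM * WNITER * TN / 2];
    memset(hits, 0, sizeof(hits));

    /* Mirror the shader's loop nest and sums_idx computation. */
    for (unsigned wsic = 0; wsic < WNITER; wsic++)
        for (unsigned cc = 0; cc < TN; cc++)
            for (unsigned wsir = 0; wsir < WMITER; wsir++)
                for (unsigned cr = 0; cr < TM / 2; cr++)
                    hits[(wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr]++;

    for (unsigned i = 0; i < WMITER * TM * WNITER * TN / 2; i++)
        if (hits[i] != 1) { puts("index collision"); return 1; }
    puts("each packed slot is written exactly once");
    return 0;
}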
@@ -341,54 +264,6 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
-#ifdef COOPMAT
-#ifdef MUL_MAT_ID
-    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-            [[unroll]] for (uint col = 0; col < BN; col += storestride) {
-                const uint row_i = dc + cm_col * TN + col + store_c;
-                if (row_i >= _ne1) break;
-
-                const u16vec2 row_idx = row_ids[row_i];
-
-                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-            }
-        }
-    }
-#else
-    const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
-
-    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
-
-            if (is_aligned && is_in_bounds) {
-                // Full coopMat is within bounds and stride_d is aligned with 16B
-                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
-                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
-            } else if (is_in_bounds) {
-                // Full coopMat is within bounds, but stride_d is not aligned
-                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-                }
-            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
-                // Partial coopMat is within bounds
-                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
-                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-                    }
-                }
-            }
-        }
-    }
-#endif // MUL_MAT_ID
-#else
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

@@ -399,19 +274,27 @@ void main() {
                 const uint row_i = dc_warp + cc;
                 if (row_i >= _ne1) break;
 
-                const u16vec2 row_idx = row_ids[row_i];
+                const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID
-                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                    const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
 #ifdef MUL_MAT_ID
-                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + 2 * cr < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+                    }
+                    if (dr_warp + 2 * cr + 1 < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
+                    }
 #else
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+                    }
+                    if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                     }
 #endif // MUL_MAT_ID
                 }
             }
         }
     }
-#endif // COOPMAT
 }
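Since one vec2 now spans two output rows, the store path guards rows 2 * cr and 2 * cr + 1 separately: with an odd p.M the .y half of the last pair falls outside the matrix. A tiny C illustration of that guard with hypothetical values:

#include <stdio.h>

typedef struct { float x, y; } acc_vec2;

int main(void) {
    const unsigned M = 5;                 /* hypothetical odd row count */
    const unsigned dr_warp = 4, cr = 0;   /* last row pair of the tile  */
    const acc_vec2 sum = { 1.0f, 2.0f };
    float out[6] = { 0 };

    if (dr_warp + 2 * cr < M)     out[dr_warp + 2 * cr]     = sum.x; /* row 4: stored  */
    if (dr_warp + 2 * cr + 1 < M) out[dr_warp + 2 * cr + 1] = sum.y; /* row 5: skipped */

    printf("out[4]=%g out[5]=%g\n", out[4], out[5]);
    return 0;
}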

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 35 additions & 22 deletions
@@ -62,21 +62,21 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
 void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) {
 }
 
-ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+ACC_TYPE mmq_dot_product(const uint ib_a) {
     int32_t q_sum = 0;
     [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
         const uint32_t vui = cache_a[ib_a].qs[iqs];
         const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
                                      (vui >> 4) & 0x0F0F0F0F);
 
-        const int32_t qs_b0 = cache_b[ib_b].qs[iqs];
-        const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4];
+        const int32_t qs_b0 = cache_b.qs[iqs];
+        const int32_t qs_b1 = cache_b.qs[iqs + 4];
 
         q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
         q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
     }
 
-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
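dotPacked4x8EXT computes the dot product of two int32s viewed as four signed 8-bit lanes; the Q4 path above feeds it the low and high nibbles of each packed word. A scalar C emulation for reference (dot_packed_4x8 is a hypothetical stand-in for the GLSL built-in, and the inputs are arbitrary):

#include <stdint.h>
#include <stdio.h>

/* Hypothetical scalar model of GLSL's dotPacked4x8EXT: dot product of two
   int32s interpreted as four signed 8-bit lanes. */
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; i++)
        sum += (int32_t)(int8_t)((uint32_t)a >> (8 * i))
             * (int32_t)(int8_t)((uint32_t)b >> (8 * i));
    return sum;
}

int main(void) {
    const uint32_t vui  = 0x84C32B17u;                         /* arbitrary packed Q4 data */
    const int32_t  lo   = (int32_t)(vui & 0x0F0F0F0F);         /* low nibbles, one per lane */
    const int32_t  hi   = (int32_t)((vui >> 4) & 0x0F0F0F0F);  /* high nibbles */
    const int32_t  qs_b = 0x01FF0203;                          /* arbitrary packed Q8_1 bytes */

    printf("partial sums: %d, %d\n", dot_packed_4x8(lo, qs_b), dot_packed_4x8(hi, qs_b));
    return 0;
}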

@@ -140,7 +140,7 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     }
 }
 
-ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+ACC_TYPE mmq_dot_product(const uint ib_a) {
     int32_t q_sum = 0;
     [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
         const uint32_t vui = cache_a[ib_a].qs[iqs];
@@ -150,14 +150,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
         const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
                             | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
 
-        const int32_t qs_b0 = cache_b[ib_b].qs[iqs];
-        const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4];
+        const int32_t qs_b0 = cache_b.qs[iqs];
+        const int32_t qs_b1 = cache_b.qs[iqs + 4];
 
         q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
         q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
     }
 
-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
 #endif
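The context lines above show the Q5 high-bit trick: multiplying the 4-bit nibble from qh by 0x02040810 and masking with 0x10101010 scatters bits 0..3 to bit positions 4, 12, 20 and 28, one per byte lane, with no carry interference. A C check of that identity:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    for (uint32_t nib = 0; nib < 16; nib++) {
        const uint32_t scattered = (nib * 0x02040810u) & 0x10101010u;
        /* bit i of nib must land at bit position 4 + 8*i */
        for (int i = 0; i < 4; i++) {
            const uint32_t expect = ((nib >> i) & 1u) << (4 + 8 * i);
            if ((scattered & (0x10u << (8 * i))) != expect) {
                printf("mismatch at nib=%u lane=%d\n", nib, i);
                return 1;
            }
        }
    }
    puts("0x02040810 scatter verified: (b0,b1,b2,b3) -> bits (4,12,20,28)");
    return 0;
}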
@@ -191,16 +191,16 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     }
 }
 
-ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+ACC_TYPE mmq_dot_product(const uint ib_a) {
     int32_t q_sum = 0;
     [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
         const int32_t qs_a = cache_a[ib_a].qs[iqs];
-        const int32_t qs_b = cache_b[ib_b].qs[iqs];
+        const int32_t qs_b = cache_b.qs[iqs];
 
         q_sum += dotPacked4x8EXT(qs_a, qs_b);
     }
 
-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
 #endif
@@ -247,7 +247,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
 
     if (iqs == 0) {
-        buf_a[buf_ib].scales = u8vec2(data_a[ib_k].scales[iqs_k / 4], data_a[ib_k].scales[iqs_k / 4 + 1]);
+        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
         buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
     }
 }
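The Q2_K scale load now fetches both adjacent scale bytes in one 16-bit read and splits them with unpack8, replacing two byte-indexed accesses. A C equivalent of the split, assuming the little-endian layout the packed buffers use:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint16_t packed = 0xB3A1;     /* two adjacent scale bytes, arbitrary values */
    const uint8_t scales[2] = {
        (uint8_t)(packed & 0xFF),       /* unpack8(...).x -> scales[iqs_k / 4]     */
        (uint8_t)(packed >> 8),         /* unpack8(...).y -> scales[iqs_k / 4 + 1] */
    };
    printf("scale0=0x%02X scale1=0x%02X\n", scales[0], scales[1]);
    return 0;
}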
@@ -261,26 +261,39 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     }
 }
 
-ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+ACC_TYPE mmq_dot_product(const uint ib_a) {
     int32_t sum_d = 0;
     int32_t sum_m = 0;
 
-    const i32vec2 scales = i32vec2(cache_a[ib_a].scales);
-    i32vec2 scale_m = scales >> 4;
+    uint8_t scale = cache_a[ib_a].scales[0];
+    int32_t scale_m = int32_t(scale >> 4);
     scale_m |= scale_m << 8;
     scale_m |= scale_m << 16;
 
-    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
-        const uint idx_half = iqs / 4;
-        const uint qs_shift = (iqs % 4) * 2;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint qs_shift = iqs * 2;
+
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303);
+
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+    }
+
+    scale = cache_a[ib_a].scales[1];
+    scale_m = int32_t(scale >> 4);
+    scale_m |= scale_m << 8;
+    scale_m |= scale_m << 16;
+
+    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+        const uint qs_shift = (iqs - 4) * 2;
 
-        const int32_t qs_a = int32_t((cache_a[ib_a].qs[idx_half] >> qs_shift) & 0x03030303);
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303);
 
-        sum_d += dotPacked4x8EXT(qs_a, cache_b[ib_b].qs[iqs]) * (scales[idx_half] & 0xF);
-        sum_m += dotPacked4x8EXT(scale_m[idx_half], cache_b[ib_b].qs[iqs]);
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
     }
 
-    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
 #endif
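Each Q2_K scale byte packs a 4-bit multiplier in its low nibble and a 4-bit min in its high nibble; broadcasting the min into all four byte lanes lets a single packed dot product against the B quants accumulate the min correction for four values at once. A C model of the broadcast and the split sum_d/sum_m accumulation (dot_packed_4x8 as in the earlier sketch, inputs arbitrary):

#include <stdint.h>
#include <stdio.h>

/* Same hypothetical scalar model of dotPacked4x8EXT as in the earlier sketch. */
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; i++)
        sum += (int32_t)(int8_t)((uint32_t)a >> (8 * i))
             * (int32_t)(int8_t)((uint32_t)b >> (8 * i));
    return sum;
}

int main(void) {
    const uint8_t scale = 0x7C;               /* low nibble: multiplier, high nibble: min */
    int32_t scale_m = (int32_t)(scale >> 4);  /* 0x00000007 */
    scale_m |= scale_m << 8;
    scale_m |= scale_m << 16;                 /* 0x07070707: min in every byte lane */

    const int32_t qs_a = 0x03010200;          /* 2-bit quants, one per byte lane */
    const int32_t qs_b = 0x01FF0102;          /* packed signed Q8_1 bytes */

    const int32_t sum_d = dot_packed_4x8(qs_a, qs_b) * (scale & 0xF);
    const int32_t sum_m = dot_packed_4x8(scale_m, qs_b);
    printf("sum_d=%d sum_m=%d\n", sum_d, sum_m);
    return 0;
}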
