Skip to content

Commit e7cab55

Browse files
author
Stefan Savic
committed
vulkan: add ACC_TYPE_VEC2 optimization for MMQ (PR ggml-org#16536)
Signed-off-by: Stefan Savic <[email protected]>
1 parent 07c0ee4 commit e7cab55

File tree

2 files changed

+133
-82
lines changed

2 files changed

+133
-82
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -183,10 +183,10 @@ void main() {
183183
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
184184
#endif
185185

186-
ACC_TYPE sums[WMITER * TM * WNITER * TN];
186+
ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN / 2];
187187

188-
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
189-
sums[i] = ACC_TYPE(0.0f);
188+
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
189+
sums[i] = ACC_TYPE_VEC2(0.0f);
190190
}
191191

192192
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
@@ -240,10 +240,9 @@ void main() {
240240
block_b_to_registers(ib);
241241

242242
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
243-
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
244-
const uint cache_a_idx = wsir * TM + cr;
245-
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
246-
243+
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
244+
const uint cache_a_idx = wsir * TM + cr * 2;
245+
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
247246
sums[sums_idx] += mmq_dot_product(cache_a_idx);
248247
}
249248
}
@@ -273,15 +272,21 @@ void main() {
273272

274273
const u16vec2 row_idx = row_ids[row_i - ic * BN];
275274
#endif // MUL_MAT_ID
276-
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
277-
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
275+
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
276+
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
278277
#ifdef MUL_MAT_ID
279-
if (dr_warp + cr < p.M) {
280-
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
278+
if (dr_warp + 2 * cr < p.M) {
279+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
280+
}
281+
if (dr_warp + 2 * cr + 1 < p.M) {
282+
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
281283
}
282284
#else
283-
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
284-
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
285+
if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
286+
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
287+
}
288+
if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
289+
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
285290
}
286291
#endif // MUL_MAT_ID
287292
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl

Lines changed: 115 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -62,21 +62,25 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
6262
void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) {
6363
}
6464

65-
ACC_TYPE mmq_dot_product(const uint ib_a) {
66-
int32_t q_sum = 0;
65+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
66+
i32vec2 q_sum = i32vec2(0);
6767
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
68-
const uint32_t vui = cache_a[ib_a].qs[iqs];
69-
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
70-
(vui >> 4) & 0x0F0F0F0F);
68+
const u32vec2 vui = u32vec2(cache_a[ib_a ].qs[iqs],
69+
cache_a[ib_a + 1].qs[iqs]);
70+
const i32vec4 qs_a = i32vec4(vui.x & 0x0F0F0F0F, (vui.x >> 4) & 0x0F0F0F0F,
71+
vui.y & 0x0F0F0F0F, (vui.y >> 4) & 0x0F0F0F0F);
72+
const i32vec2 qs_b = i32vec2(cache_b.qs[iqs],
73+
cache_b.qs[iqs + 4]);
74+
75+
q_sum.x += dotPacked4x8EXT(qs_a.x, qs_b.x);
76+
q_sum.y += dotPacked4x8EXT(qs_a.z, qs_b.x);
77+
q_sum.x += dotPacked4x8EXT(qs_a.y, qs_b.y);
78+
q_sum.y += dotPacked4x8EXT(qs_a.w, qs_b.y);
7179

72-
const int32_t qs_b0 = cache_b.qs[iqs];
73-
const int32_t qs_b1 = cache_b.qs[iqs + 4];
74-
75-
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
76-
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
7780
}
7881

79-
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
82+
return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1),
83+
mul_q8_1(q_sum.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
8084
}
8185
#endif // MMQ_SHMEM
8286

@@ -140,24 +144,35 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
140144
}
141145
}
142146

143-
ACC_TYPE mmq_dot_product(const uint ib_a) {
144-
int32_t q_sum = 0;
147+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
148+
i32vec2 q_sum = i32vec2(0);
145149
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
146-
const uint32_t vui = cache_a[ib_a].qs[iqs];
147-
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
148-
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
149-
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
150-
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
151-
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
152-
153-
const int32_t qs_b0 = cache_b.qs[iqs];
154-
const int32_t qs_b1 = cache_b.qs[iqs + 4];
155-
156-
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
157-
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
150+
const i32vec2 qs_b = i32vec2(cache_b.qs[iqs ],
151+
cache_b.qs[iqs + 4]);
152+
153+
const u32vec2 vui = u32vec2(cache_a[ib_a ].qs[iqs],
154+
cache_a[ib_a + 1].qs[iqs]);
155+
156+
const int32_t qh_0 = int32_t(cache_a[ib_a ].qh >> (4 * iqs));
157+
const int32_t qh_1 = int32_t(cache_a[ib_a + 1].qh >> (4 * iqs));
158+
159+
const i32vec2 qs_a0 = i32vec2(int32_t(vui.x & 0x0F0F0F0F) | ((qh_0 & 0xF) * 0x02040810) & 0x10101010, // (0,1,2,3) -> (4,12,20,28)
160+
int32_t((vui.x >> 4) & 0x0F0F0F0F) | (((qh_0 >> 16) & 0xF) * 0x02040810) & 0x10101010); // (16,17,18,19) -> (4,12,20,28)
161+
162+
163+
const i32vec2 qs_a1 = i32vec2(int32_t(vui.y & 0x0F0F0F0F) | ((qh_1 & 0xF) * 0x02040810) & 0x10101010,
164+
int32_t((vui.y >> 4) & 0x0F0F0F0F) | (((qh_1 >> 16) & 0xF) * 0x02040810) & 0x10101010);
165+
166+
q_sum.x += dotPacked4x8EXT(qs_a0.x, qs_b.x);
167+
q_sum.y += dotPacked4x8EXT(qs_a1.x, qs_b.x);
168+
169+
q_sum.x += dotPacked4x8EXT(qs_a0.y, qs_b.y);
170+
q_sum.y += dotPacked4x8EXT(qs_a1.y, qs_b.y);
171+
158172
}
159173

160-
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
174+
return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1),
175+
mul_q8_1(q_sum.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
161176
}
162177
#endif // MMQ_SHMEM
163178
#endif
@@ -191,16 +206,16 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
191206
}
192207
}
193208

194-
ACC_TYPE mmq_dot_product(const uint ib_a) {
195-
int32_t q_sum = 0;
209+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
210+
i32vec2 q_sum = i32vec2(0);
196211
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
197-
const int32_t qs_a = cache_a[ib_a].qs[iqs];
198212
const int32_t qs_b = cache_b.qs[iqs];
199-
200-
q_sum += dotPacked4x8EXT(qs_a, qs_b);
213+
q_sum.x += dotPacked4x8EXT(cache_a[ib_a ].qs[iqs], qs_b);
214+
q_sum.y += dotPacked4x8EXT(cache_a[ib_a + 1].qs[iqs], qs_b);
201215
}
202216

203-
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
217+
return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1),
218+
mul_q8_1(q_sum.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
204219
}
205220
#endif // MMQ_SHMEM
206221
#endif
@@ -261,21 +276,34 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
261276
}
262277
}
263278

264-
ACC_TYPE mmq_dot_product(const uint ib_a) {
265-
int32_t sum_d = 0;
266-
int32_t sum_m = 0;
279+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
280+
i32vec2 sum_d = i32vec2(0);
281+
i32vec2 sum_m = i32vec2(0);
267282

268283
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
269-
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
270-
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
271-
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
284+
const u8vec2 scale = u8vec2(cache_a[ib_a ].scales[iqs / 4],
285+
cache_a[ib_a + 1].scales[iqs / 4]);
286+
287+
const i32vec2 scale_m = i32vec2(int32_t(scale.x >> 4) * 0x01010101,
288+
int32_t(scale.y >> 4) * 0x01010101); // Duplicate 8-bit value across 32-bits.
289+
290+
const i32vec2 qs_a = i32vec2((cache_a[ib_a ].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303,
291+
(cache_a[ib_a + 1].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
292+
293+
const int32_t qs_b = cache_b.qs[iqs];
294+
sum_d.x += dotPacked4x8EXT(qs_a.x, qs_b) * (scale.x & 0xF);
295+
sum_d.y += dotPacked4x8EXT(qs_a.y, qs_b) * (scale.y & 0xF);
296+
297+
sum_m.x += dotPacked4x8EXT(scale_m.x, qs_b);
298+
sum_m.y += dotPacked4x8EXT(scale_m.y, qs_b);
272299

273-
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
274-
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
275300
}
276301

277-
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
302+
return ACC_TYPE_VEC2(mul_q8_1(sum_d.x, sum_m.x, cache_a[ib_a ].dm, cache_b.ds, 1),
303+
mul_q8_1(sum_d.y, sum_m.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
304+
278305
}
306+
279307
#endif // MMQ_SHMEM
280308
#endif
281309

@@ -321,27 +349,34 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
321349
}
322350
}
323351

324-
ACC_TYPE mmq_dot_product(const uint ib_a) {
325-
float result = 0.0;
326-
int32_t q_sum = 0;
352+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
353+
vec2 result = vec2(0.0);
354+
i32vec2 q_sum = i32vec2(0);
327355

328356
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
329357
// Subtract 4 from the quants to correct the 3rd bit offset
330-
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
358+
const i32vec2 qs_a = i32vec2(pack32(unpack8(int32_t((cache_a[ib_a ].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)),
359+
pack32(unpack8(int32_t((cache_a[ib_a + 1].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)));
331360

332-
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
361+
q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
362+
q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
333363
}
334-
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
335-
q_sum = 0;
364+
result.x += float(cache_a[ib_a ].d_scales[0]) * float(q_sum.x);
365+
result.y += float(cache_a[ib_a + 1].d_scales[0]) * float(q_sum.y);
366+
q_sum = i32vec2(0);
336367

337368
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
338-
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
369+
const i32vec2 qs_a = i32vec2(pack32(unpack8(int32_t((cache_a[ib_a ].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)),
370+
pack32(unpack8(int32_t((cache_a[ib_a + 1].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)));
339371

340-
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
372+
q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
373+
q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
341374
}
342-
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
375+
result.x += float(cache_a[ib_a ].d_scales[1]) * float(q_sum.x);
376+
result.y += float(cache_a[ib_a + 1].d_scales[1]) * float(q_sum.y);
343377

344-
return ACC_TYPE(cache_b.ds.x * result);
378+
return ACC_TYPE_VEC2(cache_b.ds.x * result.x,
379+
cache_b.ds.x * result.y);
345380
}
346381
#endif // MMQ_SHMEM
347382
#endif
@@ -398,20 +433,24 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
398433
}
399434
}
400435

401-
ACC_TYPE mmq_dot_product(const uint ib_a) {
402-
int32_t q_sum = 0;
436+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
437+
i32vec2 q_sum = i32vec2(0);
403438

404439
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
405440
#if defined(DATA_A_Q4_K)
406-
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
441+
const i32vec2 qs_a = i32vec2((cache_a[ib_a ].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F,
442+
(cache_a[ib_a + 1].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
407443
#else // defined(DATA_A_Q5_K)
408-
const int32_t qs_a = cache_a[ib_a].qs[iqs];
444+
const i32vec2 qs_a = i32vec2(cache_a[ib_a ].qs[iqs],
445+
cache_a[ib_a + 1].qs[iqs]);
409446
#endif
410447

411-
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
448+
q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
449+
q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
412450
}
413451

414-
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
452+
return ACC_TYPE_VEC2(mul_q8_1(q_sum.x, cache_a[ib_a ].dm, cache_b.ds, 1),
453+
mul_q8_1(q_sum.y, cache_a[ib_a + 1].dm, cache_b.ds, 1));
415454
}
416455
#endif // MMQ_SHMEM
417456
#endif
@@ -475,26 +514,33 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
475514
}
476515
}
477516

478-
ACC_TYPE mmq_dot_product(const uint ib_a) {
479-
float result = 0.0;
480-
int32_t q_sum = 0;
517+
ACC_TYPE_VEC2 mmq_dot_product(const uint ib_a) {
518+
vec2 result = vec2(0.0);
519+
i32vec2 q_sum = i32vec2(0);
481520

482521
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
483-
const int32_t qs_a = cache_a[ib_a].qs[iqs];
522+
const i32vec2 qs_a = i32vec2(cache_a[ib_a ].qs[iqs],
523+
cache_a[ib_a + 1].qs[iqs]);
484524

485-
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
525+
q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
526+
q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
486527
}
487-
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
488-
q_sum = 0;
528+
result.x += float(cache_a[ib_a ].d_scales[0]) * float(q_sum.x);
529+
result.y += float(cache_a[ib_a + 1].d_scales[0]) * float(q_sum.y);
530+
q_sum = i32vec2(0);
489531

490532
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
491-
const int32_t qs_a = cache_a[ib_a].qs[iqs];
533+
const i32vec2 qs_a = i32vec2(cache_a[ib_a ].qs[iqs],
534+
cache_a[ib_a + 1].qs[iqs]);
492535

493-
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
536+
q_sum.x += dotPacked4x8EXT(qs_a.x, cache_b.qs[iqs]);
537+
q_sum.y += dotPacked4x8EXT(qs_a.y, cache_b.qs[iqs]);
494538
}
495-
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
539+
result.x += float(cache_a[ib_a ].d_scales[1]) * float(q_sum.x);
540+
result.y += float(cache_a[ib_a + 1].d_scales[1]) * float(q_sum.y);
496541

497-
return ACC_TYPE(cache_b.ds.x * result);
542+
return ACC_TYPE_VEC2(cache_b.ds.x * result.x,
543+
cache_b.ds.x * result.y);
498544
}
499545
#endif // MMQ_SHMEM
500546
#endif

0 commit comments

Comments
 (0)