Skip to content

Commit 02c1813

Browse files
authored
Vulkan: Add Integer Dot Product mul_mat_vec shader for legacy quants (ggml-org#14903)
* vulkan: Add Integer Dot Product mul_mat_vec shader for legacy quants * vulkan: use subgroup operations for quantize_q8_1 shader * vulkan: add q8_1_x4 type with 128-bit alignment, use in mul_mat_vecq shader * vulkan: use q8_1_x4 blocks in mul_mmq shader * vulkan: do 8 calculations per invocation instead of 32 in mul_mat_vecq, similar to mul_mat_vec * vulkan: tune mul_mat_vecq performance for Intel * vulkan: fix quantizing issue when tensor is not divisible by 128 * vulkan: adapt integer dot mmv to mmv small m optimization (ggml-org#15355) * vulkan: allow all subgroup modes for mmv and mmvq * vulkan: use prealloc intermediate reuse for mmvq path * vulkan: tune mmvq for Intel, AMD GCN and Nvidia RTX 3090 * vulkan: adapt mmv quantize_y path to conditional sync logic * vulkan: disable q8_0 mmvq on Nvidia * vulkan: enable q8_0 on Nvidia pre-turing * fix prealloc sync condition * fix llvmpipe subgroup 8 issue
1 parent 77dee9d commit 02c1813

File tree

8 files changed

+559
-115
lines changed

8 files changed

+559
-115
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 262 additions & 91 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#extension GL_EXT_control_flow_attributes : enable
22
#extension GL_EXT_shader_16bit_storage : require
33
#extension GL_EXT_shader_8bit_storage : require
4-
#if USE_SUBGROUP_ADD
4+
5+
#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
56
#extension GL_KHR_shader_subgroup_basic : require
67
#extension GL_KHR_shader_subgroup_arithmetic : require
78
#endif
@@ -12,10 +13,19 @@
1213

1314
#include "types.comp"
1415

16+
#ifndef MMQ
1517
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
18+
#else
19+
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
20+
#endif
21+
1622
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
23+
#ifdef B_TYPE_VEC2
1724
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
25+
#endif
26+
#ifdef B_TYPE_VEC4
1827
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
28+
#endif
1929

2030
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
2131
#ifdef MUL_MAT_ID
@@ -92,6 +102,23 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
92102
layout (constant_id = 1) const uint NUM_ROWS = 1;
93103
layout (constant_id = 2) const uint NUM_COLS = 1;
94104

105+
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
106+
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
107+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
108+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
109+
temp[j][n] = subgroupAdd(temp[j][n]);
110+
}
111+
}
112+
113+
if (tid == 0) {
114+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
115+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
116+
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
117+
}
118+
}
119+
}
120+
}
121+
#else
95122
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
96123

97124
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
@@ -152,3 +179,4 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
152179
}
153180
#endif
154181
}
182+
#endif
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#version 450
2+
3+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
4+
#extension GL_EXT_integer_dot_product : require
5+
6+
#define MMQ
7+
#define B_TYPE block_q8_1_x4
8+
9+
#include "mul_mat_vec_base.comp"
10+
11+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
12+
13+
#define K_PER_ITER 8
14+
15+
#include "mul_mmq_funcs.comp"
16+
17+
uint a_offset, b_offset, d_offset;
18+
19+
int32_t cache_b_qs[2];
20+
vec2 cache_b_ds;
21+
22+
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
23+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
24+
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
25+
26+
// Preload data_b block
27+
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
28+
const uint b_qs_idx = tid % 4;
29+
const uint b_block_idx_outer = b_block_idx / 4;
30+
const uint b_block_idx_inner = b_block_idx % 4;
31+
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
32+
33+
#if QUANT_R == 2
34+
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
35+
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
36+
#else
37+
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
38+
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
39+
#endif
40+
41+
uint ibi = first_row*p.ncols;
42+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
43+
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
44+
ibi += p.ncols;
45+
46+
int32_t q_sum = 0;
47+
#if QUANT_R == 2
48+
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
49+
q_sum += dotPacked4x8EXT(data_a_qs.x,
50+
cache_b_qs[0]);
51+
q_sum += dotPacked4x8EXT(data_a_qs.y,
52+
cache_b_qs[1]);
53+
#else
54+
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
55+
q_sum += dotPacked4x8EXT(data_a_qs,
56+
cache_b_qs[0]);
57+
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
58+
q_sum += dotPacked4x8EXT(data_a_qs,
59+
cache_b_qs[1]);
60+
#endif
61+
62+
#if QUANT_AUXF == 1
63+
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
64+
#else
65+
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
66+
#endif
67+
}
68+
}
69+
}
70+
71+
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
72+
const uint tid = gl_LocalInvocationID.x;
73+
74+
get_offsets(a_offset, b_offset, d_offset);
75+
a_offset /= QUANT_K;
76+
b_offset /= QUANT_K_Q8_1;
77+
78+
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
79+
80+
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
81+
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
82+
temp[j][n] = FLOAT_TYPE(0.0f);
83+
}
84+
}
85+
86+
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
87+
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
88+
num_iters++;
89+
}
90+
int unroll_count = 4;
91+
uint unrolled_iters = num_iters & ~(unroll_count - 1);
92+
93+
uint i = 0;
94+
while (i < unrolled_iters) {
95+
// Manually partially unroll the loop
96+
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
97+
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
98+
i++;
99+
}
100+
}
101+
102+
unroll_count = 2;
103+
unrolled_iters = num_iters & ~(unroll_count - 1);
104+
105+
#if K_PER_ITER == 2
106+
if ((p.ncols & 1) != 0 &&
107+
unrolled_iters == num_iters &&
108+
unrolled_iters > 0) {
109+
unrolled_iters -= unroll_count;
110+
}
111+
#endif
112+
113+
while (i < unrolled_iters) {
114+
// Manually partially unroll the loop
115+
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
116+
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
117+
i++;
118+
}
119+
}
120+
while (i < num_iters) {
121+
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
122+
i++;
123+
}
124+
125+
reduce_result(temp, d_offset, first_row, num_rows, tid);
126+
}
127+
128+
void main() {
129+
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
130+
131+
// do NUM_ROWS at a time, unless there aren't enough remaining rows
132+
if (first_row + NUM_ROWS <= p.stride_d) {
133+
compute_outputs(first_row, NUM_ROWS);
134+
} else {
135+
if (first_row >= p.stride_d) {
136+
return;
137+
}
138+
compute_outputs(first_row, p.stride_d - first_row);
139+
}
140+
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
2828
#if defined(A_TYPE_PACKED32)
2929
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
3030
#endif
31-
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
31+
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
3232
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
3333

3434
#ifdef MUL_MAT_ID
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
9898
#endif
9999

100100
#define LOAD_VEC_A (4 * QUANT_R)
101-
#define LOAD_VEC_B 4
101+
#define LOAD_VEC_B 16
102102

103103
#ifdef MUL_MAT_ID
104104
shared u16vec2 row_ids[4096];
@@ -270,15 +270,22 @@ void main() {
270270
const uint iqs = idx & 0x7;
271271
#else
272272
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
273+
const uint ib_outer = ib / 4;
274+
const uint ib_inner = ib % 4;
275+
273276
const uint iqs = loadr_b;
274277
#endif
275278

276279
const uint buf_ib = loadc_b + l;
277280

278281
if (iqs == 0) {
279-
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282+
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280283
}
281-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
285+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
286+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
287+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
288+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
282289
}
283290

284291
barrier();
@@ -349,7 +356,7 @@ void main() {
349356
cache_b_qs[cc * (BK / 4) + idx_k]);
350357
}
351358

352-
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
359+
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
353360
}
354361
}
355362
}

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
1616
(vui >> 4) & 0x0F0F0F0F);
1717
}
1818

19-
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
20-
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
19+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
20+
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
2121
}
2222
#endif
2323

@@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
2929
(vui >> 4) & 0x0F0F0F0F);
3030
}
3131

32-
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
33-
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
32+
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
33+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
3434
}
3535
#endif
3636

@@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
5050
return i32vec2(v0, v1);
5151
}
5252

53-
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
54-
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
53+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
54+
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
5555
}
5656
#endif
5757

@@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
6969
return i32vec2(v0, v1);
7070
}
7171

72-
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
73-
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
72+
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
73+
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
7474
}
7575
#endif
7676

@@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
8181
data_a[ib].qs[iqs * 2 + 1]));
8282
}
8383

84-
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
84+
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
8585
return ACC_TYPE(float(q_sum) * da * dsb.x);
8686
}
8787
#endif

0 commit comments

Comments
 (0)