 
 #ifdef USE_SUBGROUPS
 #extension GL_KHR_shader_subgroup_basic : require
-#extension GL_KHR_shader_subgroup_clustered : require
+#extension GL_KHR_shader_subgroup_arithmetic : require
 
 #define INVOCATION_ID gl_SubgroupInvocationID.x
 #else
 #define INVOCATION_ID gl_LocalInvocationID.x
 #endif
 
 #define MMQ
-#define B_TYPE block_q8_1_x4_packed128
+#define B_TYPE block_q8_1_x4
 
 #include "mul_mat_vec_base.comp"
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-#define K_PER_ITER 32
-
-const uint GROUP_SIZE = 8;
-const uint GROUPS_PER_WARP = (BLOCK_SIZE / GROUP_SIZE);
+#define K_PER_ITER 8
 
 #include "mul_mmq_funcs.comp"
 
-uint a_offset, b_offset, d_offset, y_offset;
+uint a_offset, b_offset, d_offset;
 
-#ifdef USE_SUBGROUPS
-void reduce_result_grouped(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid_in_group) {
-    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            temp[j][n] = subgroupClusteredAdd(temp[j][n], GROUP_SIZE);
-        }
-    }
-
-    if (tid_in_group == 0) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
-            }
-        }
-    }
-}
-#else
-void reduce_result_grouped(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid_in_group) {
-    const uint tid = INVOCATION_ID;
-    // sum up partial sums and write back result
-    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            tmpsh[j][n][tid] = temp[j][n];
-        }
-    }
-    barrier();
-    if (tid_in_group < 4) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[j][n][tid] += tmpsh[j][n][tid + 4];
-            }
-        }
-    }
-    barrier();
-    if (tid_in_group < 2) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[j][n][tid] += tmpsh[j][n][tid + 2];
-            }
-        }
-    }
-    barrier();
-    if (tid_in_group == 0) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][tid] + tmpsh[j][n][tid + 1]);
-            }
-        }
-    }
-}
-#endif
-
-ivec4 cache_b_qs[2];
+int32_t cache_b_qs[2];
 vec2 cache_b_ds;
 
-void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid_in_group, const uint i) {
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-        const uint col = i*GROUP_SIZE + K_PER_ITER*tid_in_group;
+        const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
 
         // Preload data_b block
         const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
+        const uint b_qs_idx = tid % 4;
         const uint b_block_idx_outer = b_block_idx / 4;
         const uint b_block_idx_inner = b_block_idx % 4;
         cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
-        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2];
-        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2 + 1];
+
+#if QUANT_R == 2
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
+#else
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
+#endif
 
         uint ibi = first_row*p.ncols;
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
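
K_PER_ITER drops from 32 to 8 here, so each invocation now covers eight consecutive q8_1 values per iteration, exactly the two 32-bit words cached in cache_b_qs, and b_qs_idx = tid % 4 picks which quarter of a 32-value q8_1 block the invocation owns. Both sides of the diff lean on the packed integer dot product from GL_EXT_integer_dot_product; a minimal standalone illustration, with made-up constants:

    // dotPacked4x8EXT treats each int32 as four signed 8-bit lanes,
    // lane 0 in the low byte, and returns their int32 dot product.
    int32_t a = 0x01020304;            // lanes: 4, 3, 2, 1
    int32_t b = 0x01010101;            // lanes: 1, 1, 1, 1
    int32_t s = dotPacked4x8EXT(a, b); // 4*1 + 3*1 + 2*1 + 1*1 = 10
    // With K_PER_ITER = 8, one invocation's slice is exactly two such
    // products: one against cache_b_qs[0], one against cache_b_qs[1].
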
@@ -102,71 +54,36 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
             int32_t q_sum = 0;
 #if QUANT_R == 2
-            i32vec2 data_a_qs = repack(a_block_idx, 0);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].x);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].x);
-            data_a_qs = repack(a_block_idx, 1);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].y);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].y);
-            data_a_qs = repack(a_block_idx, 2);
+            const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
             q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].z);
+                                     cache_b_qs[0]);
             q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].z);
-            data_a_qs = repack(a_block_idx, 3);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].w);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].w);
+                                     cache_b_qs[1]);
 #else
-            int32_t data_a_qs = repack(a_block_idx, 0);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].x);
-            data_a_qs = repack(a_block_idx, 1);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].y);
-            data_a_qs = repack(a_block_idx, 2);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].z);
-            data_a_qs = repack(a_block_idx, 3);
+            int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
             q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].w);
-            data_a_qs = repack(a_block_idx, 4);
+                                     cache_b_qs[0]);
+            data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
             q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].x);
-            data_a_qs = repack(a_block_idx, 5);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].y);
-            data_a_qs = repack(a_block_idx, 6);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].z);
-            data_a_qs = repack(a_block_idx, 7);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].w);
+                                     cache_b_qs[1]);
 #endif
 
 #if QUANT_AUXF == 1
-            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds);
+            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
 #else
-            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds);
+            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
 #endif
         }
     }
 }
 
 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
-    const uint tid_in_group = INVOCATION_ID % GROUP_SIZE;
+    const uint tid = INVOCATION_ID;
 
     get_offsets(a_offset, b_offset, d_offset);
     a_offset /= QUANT_K;
     b_offset /= QUANT_K_Q8_1;
 
-    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
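
The new trailing 4 in the mul_q8_1 calls matches the new work split: a q8_1 block's ds carries the scale and a block-wide sum term, and each block is now shared by four invocations (b_qs_idx = tid % 4), so the sum-based correction would otherwise be counted four times after the final reduction. The helper itself lives in mul_mmq_funcs.comp and is not shown in this diff; a hypothetical sketch of the shape it plausibly takes:

    // Hypothetical reconstruction, for illustration only.
    // dm = (d, m) of the A block, dsb = (d, s) of the q8_1 B block.
    FLOAT_TYPE mul_q8_1_sketch(const int32_t q_sum, const vec2 dm,
                               const vec2 dsb, const int32_t sum_divisor) {
        // dsb.y is summed over the whole 32-value block, but each invocation
        // saw only 8 of those values, so the constant term is split across
        // the cooperating invocations by sum_divisor.
        return FLOAT_TYPE(dm.x * dsb.x * float(q_sum) + dm.y * dsb.y / sum_divisor);
    }
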
@@ -175,8 +92,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         }
     }
 
-    uint num_iters = p.ncols / (K_PER_ITER * GROUP_SIZE);
-    if (num_iters * K_PER_ITER * GROUP_SIZE + K_PER_ITER*tid_in_group < p.ncols) {
+    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
         num_iters++;
     }
     int unroll_count = 4;
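
As a worked check of the new bound, assume BLOCK_SIZE = 32 for illustration: with p.ncols = 4096, num_iters = 4096 / (8 * 32) = 16, and the tail test 16*8*32 + 8*tid = 4096 + 8*tid is never below 4096, so no invocation takes an extra iteration. Only when ncols is not a multiple of K_PER_ITER * BLOCK_SIZE do the invocations whose 8-value slice still falls inside ncols bump num_iters by one.
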
@@ -186,7 +103,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     while (i < unrolled_iters) {
         // Manually partially unroll the loop
         [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-            iter(temp, first_row, num_rows, tid_in_group, i*K_PER_ITER);
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
             i++;
         }
     }
@@ -205,22 +122,20 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     while (i < unrolled_iters) {
         // Manually partially unroll the loop
         [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-            iter(temp, first_row, num_rows, tid_in_group, i*K_PER_ITER);
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
             i++;
         }
     }
     while (i < num_iters) {
-        iter(temp, first_row, num_rows, tid_in_group, i*K_PER_ITER);
+        iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
         i++;
     }
 
-    reduce_result_grouped(temp, d_offset, first_row, num_rows, tid_in_group);
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }
 
 void main() {
-    const uint group_id = INVOCATION_ID / GROUP_SIZE;
-    // 8 threads work together on a NUM_ROWS * NUM_COLS block/slice
-    const uint first_row = NUM_ROWS * (GROUPS_PER_WARP * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z) + group_id);
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
 
     // do NUM_ROWS at a time, unless there aren't enough remaining rows
     if (first_row + NUM_ROWS <= p.stride_d) {
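
With the grouped reduction gone, each workgroup now owns a single stripe of NUM_ROWS output rows, and the write-back goes through the shared reduce_result helper from mul_mat_vec_base.comp, which this diff does not show. A minimal sketch of the subgroup flavor such a reduction can take, assuming one subgroup per workgroup and using subgroupAdd from the newly required GL_KHR_shader_subgroup_arithmetic extension:

    void reduce_result_sketch(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint32_t d_offset,
                              const uint32_t first_row, const uint32_t num_rows, const uint32_t tid) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                temp[j][n] = subgroupAdd(temp[j][n]); // sum across all invocations
            }
        }
        if (tid == 0) {
            [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
                [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                    data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
                }
            }
        }
    }
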