#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable

// Workgroup geometry for the Q80-input x Q40-weight matmul kernel.
// N_THREADS threads cooperate on one tile of TILE_SIZE_D output rows;
// each thread walks its slice of the input in steps of TILE_SIZE_X blocks.
#define N_THREADS 64
#define TILE_SIZE_X 2
#define TILE_SIZE_D 16

// Number of quantized elements per Q80/Q40 block (llama.cpp-style layout).
#define Q80_Q40_BLOCK_SIZE 32

layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
// Buffer bindings. Binding 0 (the Q80 input buffer `x`) and the struct
// definitions for BatchInfo / BlockQ40 are declared earlier in this file,
// outside this excerpt — NOTE(review): confirm against the full source.
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };

// Per-workgroup parameters, computed once by thread 0 and broadcast via
// shared memory after the first barrier in main().
shared uint sharedXSlice;        // whole x-tiles per thread (xTiles / N_THREADS)
shared uint sharedXRest;         // leftover x-tiles spread over the first threads
shared uint sharedInputOffset;   // base offset of this batch's input blocks
shared uint sharedInputSizeX;    // input width in Q80 blocks
shared uint sharedOutputOffset;  // base offset of this batch's output row
shared uint sharedD;             // first output index handled by this workgroup
// Partial dot products: one accumulator per (thread, output-within-tile) pair.
shared float16_t sums[N_THREADS * TILE_SIZE_D];
// Computes TILE_SIZE_D consecutive outputs y[outputOffset + d .. d+15] as dot
// products of a Q80-quantized input row with Q40-quantized weight rows.
// Work split: gl_WorkGroupID.y selects the batch, gl_WorkGroupID.z selects the
// tile of output rows; each thread accumulates over a disjoint slice of the
// input blocks, then a shared-memory tree reduction combines the partials.
void main() {
    const uint threadIndex = gl_LocalInvocationID.x;

    if (threadIndex == 0) {
        const uint batchIndex = gl_WorkGroupID.y;
        const uint workGroupIndex = gl_WorkGroupID.z;

        const BatchInfo info = infos[batchIndex];

        // Split the input (in tiles of TILE_SIZE_X blocks) evenly across
        // threads; the first xRest threads take one extra tile each.
        // NOTE(review): assumes inputSizeX is a multiple of TILE_SIZE_X.
        const uint xTiles = info.inputSizeX / TILE_SIZE_X;
        sharedXSlice = xTiles / N_THREADS;
        sharedXRest = xTiles % N_THREADS;

        sharedInputOffset = info.inputOffset;
        sharedInputSizeX = info.inputSizeX;
        sharedOutputOffset = info.outputOffset;
        sharedD = TILE_SIZE_D * workGroupIndex;
    }

    barrier();
    memoryBarrierShared();

    // Per-thread slice bounds, measured in input-block indices.
    const uint xSlice = sharedXSlice;
    const uint xRest = sharedXRest;
    const uint xStart = (threadIndex * xSlice + min(threadIndex, xRest)) * TILE_SIZE_X;
    const uint xEnd = xStart + (xSlice + (threadIndex < xRest ? 1 : 0)) * TILE_SIZE_X;

    const uint inputOffset = sharedInputOffset;
    const uint inputSizeX = sharedInputSizeX;
    const uint outputOffset = sharedOutputOffset;
    const uint d = sharedD;

    // Dequantized copy of one Q80 input block, repacked as vec4s ordered to
    // match the Q40 nibble layout (element k in the low nibble pairs with
    // element k + BLOCK/2 in the high nibble of the same weight byte).
    f16vec4 xTemp[Q80_Q40_BLOCK_SIZE / 4];

    for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
        sums[threadIndex * TILE_SIZE_D + dt] = float16_t(0.0f);
    }

    for (uint i = xStart; i < xEnd; i += TILE_SIZE_X) {
        [[unroll]] for (uint it = 0; it < TILE_SIZE_X; it++) {
            const uint xi = inputOffset + i + it;
            const float16_t xScale = x[xi].d;

            // Load the input block once; it is reused for all TILE_SIZE_D rows.
            [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
                xTemp[j] = f16vec4(
                    x[xi].qs[j * 2],
                    x[xi].qs[j * 2 + Q80_Q40_BLOCK_SIZE / 2],
                    x[xi].qs[j * 2 + 1],
                    x[xi].qs[j * 2 + 1 + Q80_Q40_BLOCK_SIZE / 2]
                );
            }

            [[unroll]] for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
                const uint wi = (d + dt) * inputSizeX + (i + it);
                const BlockQ40 wBlock = weight[wi];

                float16_t s = float16_t(0);
                [[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
                    uint w0 = wBlock.qs[j * 2];
                    uint w1 = wBlock.qs[j * 2 + 1];
                    // Unpack four 4-bit weights and recentre them to [-8, 7].
                    ivec4 w = ivec4(
                        w0 & 0xFu,
                        w0 >> 4,
                        w1 & 0xFu,
                        w1 >> 4
                    ) - ivec4(8);
                    s += dot(xTemp[j], f16vec4(w));
                }
                // Apply both block scales once per block, not per element.
                sums[threadIndex * TILE_SIZE_D + dt] += s * xScale * wBlock.d;
            }
        }
    }

    barrier();
    memoryBarrierShared();

    // Tree reduction over threads: after log2(N_THREADS) rounds, row 0 of
    // `sums` holds the TILE_SIZE_D final dot products.
    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
        for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
            if (threadIndex < i) {
                sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt];
            }
        }
        barrier();
    }

    // Threads 0..TILE_SIZE_D-1 each write one output (TILE_SIZE_D < N_THREADS,
    // so this loop body runs at most once per thread).
    for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) {
        y[outputOffset + d + dt] = float(sums[dt]);
    }
}
0 commit comments