Skip to content

Commit 50dfb13

Browse files
authored
feat: vulkan optimization. (#196)
1 parent afa6297 commit 50dfb13

9 files changed

+72
-59
lines changed

src/nn/nn-vulkan-test.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,8 @@ void testMatmul_F32_F32_F32() {
529529
}
530530

531531
void testMatmul_Q80_Q40_F32() {
532-
#define MATMUL_Q80_Q40_N 64
533-
#define MATMUL_Q80_Q40_D 96
532+
#define MATMUL_Q80_Q40_N 512
533+
#define MATMUL_Q80_Q40_D 512
534534
execute(
535535
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
536536
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_Q80, N_BATCHES, MATMUL_Q80_Q40_N));
@@ -552,19 +552,19 @@ void testMatmul_Q80_Q40_F32() {
552552
constexpr NnUint weightSize = MATMUL_Q80_Q40_N * MATMUL_Q80_Q40_D;
553553
constexpr NnUint weightBlocks = weightSize / Q40_BLOCK_SIZE;
554554

555-
float x[xSize];
556-
float weight[weightSize];
557-
NnBlockQ40 weightQ40[weightBlocks];
555+
std::unique_ptr<float[]> x(new float[xSize]);
556+
std::unique_ptr<float[]> weight(new float[weightSize]);
557+
std::unique_ptr<NnBlockQ40[]> weightQ40(new NnBlockQ40[weightBlocks]);
558558

559559
for (NnUint i = 0; i < xSize; i++)
560-
x[i] = i * 0.01f;
560+
x[i] = i * 0.001f;
561561
for (NnUint i = 0; i < weightSize; i++)
562-
weight[i] = i * 0.001f;
562+
weight[i] = i * 0.0001f;
563563

564-
quantizeF32toQ80(x, xPipe, xSize, 1, 0);
565-
quantizeF32toQ40(weight, weightQ40, weightSize, 1, 0);
564+
quantizeF32toQ80(x.get(), xPipe, xSize, 1, 0);
565+
quantizeF32toQ40(weight.get(), weightQ40.get(), weightSize, 1, 0);
566566

567-
executor->loadWeight("matmul", 0, weightBlocks * sizeof(NnBlockQ40), (NnByte *)weightQ40);
567+
executor->loadWeight("matmul", 0, weightBlocks * sizeof(NnBlockQ40), (NnByte *)weightQ40.get());
568568

569569
// act
570570
executor->forward();
@@ -576,8 +576,8 @@ void testMatmul_Q80_Q40_F32() {
576576
for (NnUint n = 0; n < MATMUL_Q80_Q40_N; n++)
577577
sum += x[b * MATMUL_Q80_Q40_N + n] * weight[d * MATMUL_Q80_Q40_N + n];
578578
const NnUint p = b * MATMUL_Q80_Q40_D + d;
579-
const float change = (yPipe[p] - sum) / sum;
580-
assertFloat(p, change, 0.0, 0.04f);
579+
const float tolerance = sum * 0.025f;
580+
assertFloat(p, yPipe[p], sum, tolerance);
581581
}
582582
}
583583
printOk("testMatmul_Q80_Q40_F32");

src/nn/vulkan/cast-forward-f32-q80.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#extension GL_EXT_shader_explicit_arithmetic_types : enable
66

77
#define Q80_BLOCK_SIZE 32
8-
#define N_THREADS 64
8+
#define N_THREADS 256
99

1010
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
1111

@@ -61,7 +61,7 @@ void main() {
6161
const uint yiOffset = yOffset + i;
6262

6363
float amax = 0.0;
64-
for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
64+
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
6565
const float v = abs(x[xiOffset + j]);
6666
amax = max(amax, v);
6767
}

src/nn/vulkan/matmul-forward-q80-q40-f32.comp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#define Q80_Q40_BLOCK_SIZE 32
88
#define N_THREADS 256
99

10-
#define GROUP_SIZE 64
11-
#define N_THREADS_PER_GROUP (N_THREADS / GROUP_SIZE)
10+
#define N_OUTPUTS_PER_ITER 64
11+
#define N_THREADS_PER_OUTPUT (N_THREADS / N_OUTPUTS_PER_ITER)
1212

1313
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
1414

@@ -50,13 +50,14 @@ void main() {
5050
const uint batchIndex = gl_WorkGroupID.y;
5151
const uint workGroupIndex = gl_WorkGroupID.z;
5252

53-
sharedInputOffset = infos[batchIndex].inputOffset;
54-
sharedInputSizeX = infos[batchIndex].inputSizeX;
55-
sharedOutputOffset = infos[batchIndex].outputOffset;
56-
sharedInputSizeXPerGroup = (sharedInputSizeX + N_THREADS_PER_GROUP - 1) / N_THREADS_PER_GROUP;
53+
const BatchInfo info = infos[batchIndex];
54+
sharedInputOffset = info.inputOffset;
55+
sharedInputSizeX = info.inputSizeX;
56+
sharedOutputOffset = info.outputOffset;
57+
sharedInputSizeXPerGroup = (sharedInputSizeX + N_THREADS_PER_OUTPUT - 1) / N_THREADS_PER_OUTPUT;
5758

58-
const uint ySlice = infos[batchIndex].outputSizeX / nWorkGroups;
59-
const uint yRest = infos[batchIndex].outputSizeX % nWorkGroups;
59+
const uint ySlice = info.outputSizeX / nWorkGroups;
60+
const uint yRest = info.outputSizeX % nWorkGroups;
6061
sharedStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
6162
sharedEnd = sharedStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
6263
}
@@ -70,12 +71,12 @@ void main() {
7071
const uint outputOffset = sharedOutputOffset;
7172
const uint inputSizeXPerGroup = sharedInputSizeXPerGroup;
7273

73-
const uint dGroup = threadIndex / N_THREADS_PER_GROUP;
74-
const uint iGroup = threadIndex % N_THREADS_PER_GROUP;
74+
const uint dGroup = threadIndex / N_THREADS_PER_OUTPUT;
75+
const uint iGroup = threadIndex % N_THREADS_PER_OUTPUT;
7576
const uint iStart = inputSizeXPerGroup * iGroup;
7677
const uint iEnd = min(iStart + inputSizeXPerGroup, inputSizeX);
7778

78-
for (uint dBatch = sharedStart; dBatch < dEnd; dBatch += GROUP_SIZE) {
79+
for (uint dBatch = sharedStart; dBatch < dEnd; dBatch += N_OUTPUTS_PER_ITER) {
7980
const uint d = dBatch + dGroup;
8081
if (d >= dEnd) {
8182
break;
@@ -85,19 +86,20 @@ void main() {
8586
for (uint i = iStart; i < iEnd; i++) {
8687
const uint xi = inputOffset + i;
8788
const uint wi = d * inputSizeX + i;
89+
const float16_t scale = x[xi].d * weight[wi].d;
8890
[[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 2; j++) {
8991
sum += (
9092
float16_t(x[xi].qs[j]) * (float16_t(weight[wi].qs[j] & 0xF) - float16_t(8.0f)) +
9193
float16_t(x[xi].qs[j + Q80_Q40_BLOCK_SIZE / 2]) * (float16_t(weight[wi].qs[j] >> 4) - float16_t(8.0f))
92-
) * x[xi].d * weight[wi].d;
94+
) * scale;
9395
}
9496
}
9597
sums[threadIndex] = sum;
9698

9799
barrier();
98100
memoryBarrierShared();
99101

100-
[[unroll]] for (uint i = N_THREADS_PER_GROUP / 2; i > 0; i >>= 1) {
102+
[[unroll]] for (uint i = N_THREADS_PER_OUTPUT / 2; i > 0; i >>= 1) {
101103
if (iGroup < i)
102104
sums[threadIndex] += sums[threadIndex + i];
103105
barrier();

src/nn/vulkan/merge-add-forward-f32-f32.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#define N_THREADS 64
3+
#define N_THREADS 256
44

55
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
66

src/nn/vulkan/merge-add-forward-q80-f32.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#extension GL_EXT_shader_explicit_arithmetic_types : enable
66

77
#define Q80_BLOCK_SIZE 32
8-
#define N_THREADS 64
8+
#define N_THREADS 256
99

1010
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
1111

src/nn/vulkan/mul-forward-f32-f32.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#define N_THREADS 64
3+
#define N_THREADS 256
44

55
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
66

src/nn/vulkan/multi-head-att-forward-f32-f32.comp

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#version 450
22

3-
#define N_THREADS 64
3+
#extension GL_EXT_control_flow_attributes : enable
4+
5+
#define N_THREADS 256
46

57
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
68

@@ -36,7 +38,8 @@ layout(binding = 8) buffer attBufferBuffer { float att[]; };
3638

3739
shared BatchInfo sharedInfo;
3840
shared uint position;
39-
shared float sharedSum;
41+
shared float sharedMaxScore;
42+
shared float temp[N_THREADS];
4043

4144
void main() {
4245
const uint threadIndex = gl_LocalInvocationID.x;
@@ -45,7 +48,7 @@ void main() {
4548

4649
const uint kvMul = nHeads / nKvHeads;
4750
const uint headIndex = h / kvMul;
48-
const float headSizeRoot = sqrt(float(headSize));
51+
const float invHeadSizeRoot = 1.0 / sqrt(float(headSize));
4952

5053

5154
if (threadIndex == 0) {
@@ -61,56 +64,65 @@ void main() {
6164
const uint kvOffset = headIndex * headSize;
6265
const uint yOffset = sharedInfo.outputOffset + h * headSize;
6366

67+
float ms = -1e10f;
6468
for (uint p = threadIndex; p <= position; p += N_THREADS) {
65-
float score = 0.0;
6669
const uint kOffset = kvOffset + p * kvDim0;
70+
71+
float score = 0.0;
6772
for (uint i = 0; i < headSize; i++) {
6873
score += query[qOffset + i] * keyCache[kOffset + i];
6974
}
70-
att[attOffset + p] = score / headSizeRoot;
75+
score *= invHeadSizeRoot;
76+
ms = max(ms, score);
77+
att[attOffset + p] = score;
7178
}
7279

80+
temp[threadIndex] = ms;
81+
7382
barrier();
83+
memoryBarrierShared();
7484

75-
// softmax
76-
if (threadIndex == 0) {
77-
// TODO: split into multiple threads
78-
float maxScore = att[attOffset];
79-
for (uint p = 1; p <= position; p++) {
80-
maxScore = max(maxScore, att[attOffset + p]);
81-
}
85+
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
86+
if (threadIndex < i)
87+
temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
88+
barrier();
89+
}
8290

83-
float sum = 0.0;
84-
for (uint p = 0; p <= position; p++) {
85-
float v = exp(att[attOffset + p] - maxScore);
86-
att[attOffset + p] = v;
87-
sum += v;
88-
}
89-
sharedSum = sum;
91+
if (threadIndex == 0) {
92+
sharedMaxScore = temp[0];
9093
}
9194

9295
barrier();
96+
memoryBarrierShared();
9397

94-
const float sum = sharedSum;
98+
const float maxScore = sharedMaxScore;
9599

100+
float s = 0.0;
96101
for (uint p = threadIndex; p <= position; p += N_THREADS) {
97-
att[attOffset + p] /= sum;
102+
float v = exp(att[attOffset + p] - maxScore);
103+
att[attOffset + p] = v;
104+
s += v;
98105
}
99106

100-
// return
101-
for (uint i = threadIndex; i < headSize; i += N_THREADS) {
102-
y[yOffset + i] = 0.0;
107+
temp[threadIndex] = s;
108+
barrier();
109+
110+
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
111+
if (threadIndex < i)
112+
temp[threadIndex] += temp[threadIndex + i];
113+
barrier();
103114
}
104115

105-
barrier();
116+
const float yScale = 1.0 / temp[0];
106117

107118
for (uint i = threadIndex; i < headSize; i += N_THREADS) {
108119
float sum = 0.0;
120+
const uint vOffset = kvOffset + i;
109121
for (uint p = 0; p <= position; p += 1) {
110122
const float a = att[attOffset + p];
111-
const float v = valueCache[kvOffset + p * kvDim0 + i];
123+
const float v = valueCache[vOffset + p * kvDim0];
112124
sum += v * a;
113125
}
114-
y[yOffset + i] = sum;
126+
y[yOffset + i] = sum * yScale;
115127
}
116128
}

src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#define N_THREADS 64
3+
#define N_THREADS 256
44

55
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
66

src/nn/vulkan/silu-forward-f32-f32.comp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ shared uint sharedYOffset;
2121

2222
void main() {
2323
const uint threadIndex = gl_LocalInvocationID.x;
24-
const uint batchIndex = gl_GlobalInvocationID.y;
2524

2625
if (threadIndex == 0) {
2726
const uint nWorkGroups = gl_NumWorkGroups.z;

0 commit comments

Comments (0)