Skip to content

Commit 2f1cdc2

Browse files
committed
feat: optimized shaders.
1 parent 6537d3a commit 2f1cdc2

File tree

7 files changed

+63
-87
lines changed

7 files changed

+63
-87
lines changed

src/nn/nn-vulkan-test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -662,16 +662,16 @@ void testMultiheadAtt_F32_F32() {
662662
int main() {
663663
initQuants();
664664

665-
testRmsNorm_F32_F32_F32<3>();
665+
testRmsNorm_F32_F32_F32<4>();
666666
testRmsNorm_F32_F32_F32<1024>();
667-
testRmsNorm_F32_F32_F32<3191>();
667+
testRmsNorm_F32_F32_F32<3196>();
668668

669-
testSilu_F32_F32<3>();
669+
testSilu_F32_F32<4>();
670670
testSilu_F32_F32<32>();
671-
testSilu_F32_F32<101>();
671+
testSilu_F32_F32<104>();
672672

673673
testMul_F32_F32<32>();
674-
testMul_F32_F32<47>();
674+
testMul_F32_F32<48>();
675675

676676
testMergeAdd_F32_F32();
677677

@@ -686,7 +686,7 @@ int main() {
686686

687687
testCast_F32_F32<128>();
688688
testCast_F32_F32<32>();
689-
testCast_F32_F32<9>();
689+
testCast_F32_F32<8>();
690690

691691
testCast_F32_Q80<256>();
692692
testCast_F32_Q80<64>();

src/nn/nn-vulkan.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -539,13 +539,7 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
539539
} else {
540540
groupCount[2] = 32;
541541
}
542-
} else if (
543-
opConfig->code == OP_MUL ||
544-
opConfig->code == OP_SILU ||
545-
opConfig->code == OP_SHIFT
546-
)
547-
groupCount[2] = 32;
548-
else if (opConfig->code == OP_MATMUL) {
542+
} else if (opConfig->code == OP_MATMUL) {
549543
if (opConfig->weightSize.floatType == F_Q40) {
550544
// Must be synced with the shader
551545
constexpr NnUint tileSizeN = 2;
@@ -562,8 +556,17 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
562556
groupCount[2] = ((NnMultiHeadAttOpConfig *)opConfig->config)->nHeads0;
563557
else if (opConfig->code == OP_INV_RMS)
564558
groupCount[2] = ((NnInvRmsOpConfig *)opConfig->config)->nColumns;
565-
else if (opConfig->code == OP_RMS_NORM)
566-
groupCount[2] = ((NnRmsNormOpConfig *)opConfig->config)->nColumns;
559+
else if (
560+
opConfig->code == OP_EMBEDDING ||
561+
opConfig->code == OP_RMS_NORM ||
562+
opConfig->code == OP_MUL ||
563+
opConfig->code == OP_SILU ||
564+
opConfig->code == OP_SHIFT
565+
) {
566+
constexpr NnUint chunkSize = 4;
567+
assert(outputSize.x % chunkSize == 0);
568+
groupCount[2] = outputSize.x / chunkSize;
569+
}
567570
}
568571

569572
static std::vector<uint32_t> readShader(const char *fileName) {
Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#extension GL_EXT_control_flow_attributes : enable
4+
5+
#define CHUNK_SIZE 4
46
#define N_BATCHES 32
57

6-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
79

810
struct BatchInfo {
911
uint inputOffset;
@@ -20,23 +22,16 @@ layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
2022
shared uint sharedPosition;
2123

2224
void main() {
23-
const uint threadIndex = gl_LocalInvocationID.x;
24-
const uint batchIndex = gl_GlobalInvocationID.y;
25-
26-
if (threadIndex == 0) {
27-
sharedPosition = uint(x[batchIndex]);
28-
}
25+
const uint batchIndex = gl_WorkGroupID.y;
26+
const uint chunkIndex = gl_WorkGroupID.z;
27+
const uint position = uint(x[batchIndex]);
2928

30-
barrier();
31-
32-
const uint position = sharedPosition;
3329
const BatchInfo info = infos[batchIndex];
30+
const uint offset = chunkIndex * CHUNK_SIZE;
31+
const uint yOffset = info.outputOffset + offset;
32+
const uint wOffset = position * info.outputSizeX + offset;
3433

35-
const uint outputSizeX = info.outputSizeX;
36-
const uint yOffset = info.outputOffset;
37-
const uint wOffset = position * outputSizeX;
38-
39-
for (uint i = threadIndex; i < outputSizeX; i += N_THREADS) {
34+
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
4035
y[yOffset + i] = weight[wOffset + i];
4136
}
4237
}
Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#extension GL_EXT_control_flow_attributes : enable
4+
5+
#define CHUNK_SIZE 4
46
#define N_BATCHES 32
57

6-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
79

810
struct BatchInfo {
911
uint inputOffset;
@@ -21,21 +23,15 @@ layout(binding = 3) readonly uniform configBuffer {
2123
layout(binding = 4) readonly buffer multiplierBuffer { float m[]; };
2224

2325
void main() {
24-
const uint threadIndex = gl_LocalInvocationID.x;
25-
const uint nWorkGroups = gl_NumWorkGroups.z;
2626
const uint batchIndex = gl_WorkGroupID.y;
27-
const uint workGroupIndex = gl_WorkGroupID.z;
27+
const uint chunkIndex = gl_WorkGroupID.z;
2828

2929
const BatchInfo info = infos[batchIndex];
30-
const uint slice = info.inputSizeX / nWorkGroups;
31-
const uint rest = info.inputSizeX % nWorkGroups;
32-
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
33-
34-
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
30+
const uint offset = chunkIndex * CHUNK_SIZE;
3531
const uint xyOffset = info.inputOffset + offset;
3632
const uint mOffset = info.inputSizeX * batchIndex + offset;
3733

38-
for (uint i = threadIndex; i < dim; i += N_THREADS) {
34+
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
3935
y[xyOffset + i] = x[xyOffset + i] * m[mOffset + i];
4036
}
4137
}
Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#extension GL_EXT_control_flow_attributes : enable
4+
45
#define N_BATCHES 32
6+
#define CHUNK_SIZE 4
57

6-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
79

810
struct BatchInfo {
911
uint inputOffset;
@@ -22,27 +24,20 @@ layout(binding = 4) readonly uniform configBuffer {
2224
};
2325
layout(binding = 5) readonly buffer invRmsBuffer { float invRms[]; };
2426

25-
shared float sharedS;
26-
2727
void main() {
28-
const uint threadIndex = uint(gl_LocalInvocationID.x);
2928
const uint batchIndex = gl_WorkGroupID.y;
30-
const uint colIndex = gl_WorkGroupID.z;
31-
32-
if (threadIndex == 0) {
33-
sharedS = invRms[batchIndex * nColumns + colIndex];
34-
}
35-
36-
barrier();
29+
const uint chunkIndex = gl_WorkGroupID.z;
3730

3831
const BatchInfo info = infos[batchIndex];
3932
const uint dim = info.inputSizeX / nColumns;
40-
const uint offset = dim * colIndex;
33+
const uint offset = chunkIndex * CHUNK_SIZE;
34+
const uint colIndex = offset / dim;
35+
const float s = invRms[batchIndex * nColumns + colIndex];
36+
4137
const uint xOffset = info.inputOffset + offset;
4238
const uint yOffset = info.outputOffset + offset;
43-
const float s = sharedS;
4439

45-
for (uint i = threadIndex; i < dim; i += N_THREADS) {
46-
y[yOffset + i] = (x[xOffset + i] * s) * weight[i];
40+
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
41+
y[yOffset + i] = (x[xOffset + i] * s) * weight[(offset + i) % dim];
4742
}
4843
}
Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#extension GL_EXT_control_flow_attributes : enable
4+
5+
#define CHUNK_SIZE 4
46
#define N_BATCHES 32
57

6-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
79

810
struct BatchInfo {
911
uint inputOffset;
@@ -23,28 +25,17 @@ layout(binding = 4) readonly buffer indexBuffer { float indexes[]; };
2325
shared uint sharedIndex;
2426

2527
void main() {
26-
const uint threadIndex = gl_LocalInvocationID.x;
27-
const uint nWorkGroups = gl_NumWorkGroups.z;
2828
const uint batchIndex = gl_WorkGroupID.y;
29-
const uint workGroupIndex = gl_WorkGroupID.z;
30-
31-
if (threadIndex == 0) {
32-
sharedIndex = uint(indexes[batchIndex]);
33-
}
34-
35-
barrier();
29+
const uint chunkIndex = gl_WorkGroupID.z;
3630

37-
BatchInfo info = infos[batchIndex];
38-
const uint slice = info.inputSizeX / nWorkGroups;
39-
const uint rest = info.inputSizeX % nWorkGroups;
40-
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
31+
const uint index = uint(indexes[batchIndex]);
4132

42-
const uint index = sharedIndex;
43-
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);;
33+
const BatchInfo info = infos[batchIndex];
34+
const uint offset = chunkIndex * CHUNK_SIZE;
4435
const uint xOffset = info.inputOffset + offset;;
4536
const uint yOffset = index * info.inputSizeX + offset;
4637

47-
for (uint i = threadIndex; i < dim; i += N_THREADS) {
38+
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
4839
y[yOffset + i] = x[xOffset + i];
4940
}
5041
}

src/nn/vulkan/silu-forward-f32-f32.comp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#extension GL_EXT_control_flow_attributes : enable
4+
45
#define N_BATCHES 32
6+
#define CHUNK_SIZE 4
57

6-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
79

810
struct BatchInfo {
911
uint inputOffset;
@@ -17,21 +19,15 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1719
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
1820

1921
void main() {
20-
const uint threadIndex = gl_LocalInvocationID.x;
21-
const uint nWorkGroups = gl_NumWorkGroups.z;
2222
const uint batchIndex = gl_WorkGroupID.y;
23-
const uint workGroupIndex = gl_WorkGroupID.z;
23+
const uint chunkIndex = gl_WorkGroupID.z;
2424

2525
const BatchInfo info = infos[batchIndex];
26-
const uint slice = info.inputSizeX / nWorkGroups;
27-
const uint rest = info.inputSizeX % nWorkGroups;
28-
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
29-
30-
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
26+
const uint offset = chunkIndex * CHUNK_SIZE;
3127
const uint xOffset = info.inputOffset + offset;
3228
const uint yOffset = info.outputOffset + offset;
3329

34-
for (uint i = threadIndex; i < dim; i += N_THREADS) {
30+
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
3531
float v = x[xOffset + i];
3632
y[yOffset + i] = v / (1.0f + exp(-v));
3733
}

0 commit comments

Comments
 (0)