Skip to content

Commit 1f9c5f9

Browse files
committed
feat: cast-forward-f32-f32.
1 parent 2f1cdc2 commit 1f9c5f9

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

src/nn/nn-vulkan.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,8 @@ static void resolveShaderGroups(const NnOpConfig *opConfig, const NnUint batchSi
531531
if (outputSize.floatType == F_Q80) {
532532
groupCount[2] = outputSize.x / Q80_BLOCK_SIZE;
533533
} else {
534-
groupCount[2] = 32;
534+
constexpr NnUint chunkSize = 4;
535+
groupCount[2] = outputSize.x / chunkSize;
535536
}
536537
} else if (opConfig->code == OP_MERGE_ADD) {
537538
if (inputSize.floatType == F_Q80) {
Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#version 450
22

3-
#define N_THREADS 256
3+
#extension GL_EXT_control_flow_attributes : enable
4+
5+
#define CHUNK_SIZE 4
46
#define N_BATCHES 32
57

6-
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
8+
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
79

810
struct BatchInfo {
911
uint inputOffset;
@@ -17,20 +19,14 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1719
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
1820

1921
void main() {
20-
const uint threadIndex = gl_LocalInvocationID.x;
21-
const uint nWorkGroups = gl_NumWorkGroups.z;
2222
const uint batchIndex = gl_WorkGroupID.y;
23-
const uint workGroupIndex = gl_WorkGroupID.z;
24-
23+
const uint chunkIndex = gl_WorkGroupID.z;
2524
const BatchInfo info = infos[batchIndex];
26-
const uint slice = info.inputSizeX / nWorkGroups;
27-
const uint rest = info.inputSizeX % nWorkGroups;
28-
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
29-
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
25+
const uint offset = chunkIndex * CHUNK_SIZE;
3026
const uint xOffset = info.inputOffset + offset;
3127
const uint yOffset = info.outputOffset + offset;
3228

33-
for (uint i = threadIndex; i < dim; i += N_THREADS) {
29+
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
3430
y[yOffset + i] = x[xOffset + i];
3531
}
3632
}

0 commit comments

Comments
 (0)