11#version 450
22
3- #define N_THREADS 256
3+ #extension GL_EXT_control_flow_attributes : enable
4+
5+ #define CHUNK_SIZE 4
46#define N_BATCHES 32
57
6- layout(local_size_x = N_THREADS , local_size_y = 1, local_size_z = 1) in;
8+ layout(local_size_x = 1 , local_size_y = 1, local_size_z = 1) in;
79
810struct BatchInfo {
911 uint inputOffset;
@@ -23,28 +25,17 @@ layout(binding = 4) readonly buffer indexBuffer { float indexes[]; };
2325shared uint sharedIndex;
2426
2527void main() {
26- const uint threadIndex = gl_LocalInvocationID.x;
27- const uint nWorkGroups = gl_NumWorkGroups.z;
2828 const uint batchIndex = gl_WorkGroupID.y;
29- const uint workGroupIndex = gl_WorkGroupID.z;
30-
31- if (threadIndex == 0) {
32- sharedIndex = uint(indexes[batchIndex]);
33- }
34-
35- barrier();
29+ const uint chunkIndex = gl_WorkGroupID.z;
3630
37- BatchInfo info = infos[batchIndex];
38- const uint slice = info.inputSizeX / nWorkGroups;
39- const uint rest = info.inputSizeX % nWorkGroups;
40- const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
31+ const uint index = uint(indexes[batchIndex]);
4132
42- const uint index = sharedIndex ;
43- const uint dim = slice + (workGroupIndex < rest ? 1 : 0); ;
33+ const BatchInfo info = infos[batchIndex] ;
34+ const uint offset = chunkIndex * CHUNK_SIZE ;
4435 const uint xOffset = info.inputOffset + offset;;
4536 const uint yOffset = index * info.inputSizeX + offset;
4637
47- for (uint i = threadIndex ; i < dim ; i += N_THREADS ) {
38+ [[unroll]] for (uint i = 0 ; i < CHUNK_SIZE ; i++ ) {
4839 y[yOffset + i] = x[xOffset + i];
4940 }
5041}
0 commit comments