11#version 450
22
3- #define N_THREADS 64
3+ #extension GL_EXT_control_flow_attributes : enable
4+
5+ #define N_THREADS 256
46
57layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
68
@@ -36,7 +38,8 @@ layout(binding = 8) buffer attBufferBuffer { float att[]; };
3638
3739shared BatchInfo sharedInfo;
3840shared uint position;
39- shared float sharedSum;
41+ shared float sharedMaxScore;
42+ shared float temp[N_THREADS];
4043
4144void main() {
4245 const uint threadIndex = gl_LocalInvocationID.x;
@@ -45,7 +48,7 @@ void main() {
4548
4649 const uint kvMul = nHeads / nKvHeads;
4750 const uint headIndex = h / kvMul;
48- const float headSizeRoot = sqrt(float(headSize));
51+ const float invHeadSizeRoot = 1.0 / sqrt(float(headSize));
4952
5053
5154 if (threadIndex == 0) {
@@ -61,56 +64,65 @@ void main() {
6164 const uint kvOffset = headIndex * headSize;
6265 const uint yOffset = sharedInfo.outputOffset + h * headSize;
6366
67+ float ms = -1e10f;
6468 for (uint p = threadIndex; p <= position; p += N_THREADS) {
65- float score = 0.0;
6669 const uint kOffset = kvOffset + p * kvDim0;
70+
71+ float score = 0.0;
6772 for (uint i = 0; i < headSize; i++) {
6873 score += query[qOffset + i] * keyCache[kOffset + i];
6974 }
70- att[attOffset + p] = score / headSizeRoot;
75+ score *= invHeadSizeRoot;
76+ ms = max(ms, score);
77+ att[attOffset + p] = score;
7178 }
7279
80+ temp[threadIndex] = ms;
81+
7382 barrier();
83+ memoryBarrierShared();
7484
75- // softmax
76- if (threadIndex == 0) {
77- // TODO: split into multiple threads
78- float maxScore = att[attOffset];
79- for (uint p = 1; p <= position; p++) {
80- maxScore = max(maxScore, att[attOffset + p]);
81- }
85+ [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
86+ if (threadIndex < i)
87+ temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
88+ barrier();
89+ }
8290
83- float sum = 0.0;
84- for (uint p = 0; p <= position; p++) {
85- float v = exp(att[attOffset + p] - maxScore);
86- att[attOffset + p] = v;
87- sum += v;
88- }
89- sharedSum = sum;
91+ if (threadIndex == 0) {
92+ sharedMaxScore = temp[0];
9093 }
9194
9295 barrier();
96+ memoryBarrierShared();
9397
94- const float sum = sharedSum ;
98+ const float maxScore = sharedMaxScore ;
9599
100+ float s = 0.0;
96101 for (uint p = threadIndex; p <= position; p += N_THREADS) {
97- att[attOffset + p] /= sum;
102+ float v = exp(att[attOffset + p] - maxScore);
103+ att[attOffset + p] = v;
104+ s += v;
98105 }
99106
100- // return
101- for (uint i = threadIndex; i < headSize; i += N_THREADS) {
102- y[yOffset + i] = 0.0;
107+ temp[threadIndex] = s;
108+ barrier();
109+
110+ [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
111+ if (threadIndex < i)
112+ temp[threadIndex] += temp[threadIndex + i];
113+ barrier();
103114 }
104115
105- barrier() ;
116+ const float yScale = 1.0 / temp[0] ;
106117
107118 for (uint i = threadIndex; i < headSize; i += N_THREADS) {
108119 float sum = 0.0;
120+ const uint vOffset = kvOffset + i;
109121 for (uint p = 0; p <= position; p += 1) {
110122 const float a = att[attOffset + p];
111- const float v = valueCache[kvOffset + p * kvDim0 + i ];
123+ const float v = valueCache[vOffset + p * kvDim0];
112124 sum += v * a;
113125 }
114- y[yOffset + i] = sum;
126+ y[yOffset + i] = sum * yScale ;
115127 }
116128}
0 commit comments