11#version 450
22
3- #extension GL_EXT_shader_16bit_storage : require
3+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
4+ #extension GL_EXT_control_flow_attributes : enable
45
56layout (push_constant) uniform parameter
67{
@@ -11,26 +12,32 @@ layout (push_constant) uniform parameter
1112    float m0;
1213    float m1;
1314    uint n_head_log2;
15+     uint nrows_x;
1416} p;
1517
1618#include "types.comp"
1719
18- #extension GL_EXT_control_flow_attributes : enable
19- #define BLOCK_SIZE 512
20- 
21- layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
20+ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
21+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2222
2323layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
2424layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
2525layout (binding = 2) buffer D {D_TYPE data_d[];};
2626
2727shared FLOAT_TYPE vals[BLOCK_SIZE];
2828
29- void main() {
29+ // num_iters is the number of BLOCK_SIZE loop iterations we need to iterate
30+ // over all the columns. The main function tries to pass a constant here,
31+ // as if it were a template function, to allow unrolling.
32+ void soft_max(uint num_iters) {
3033    const uint tid = gl_LocalInvocationID.x;
3134    const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
3235    const uint rowy = rowx % p.KY;
3336
37+     if (rowx >= p.nrows_x) {
38+         return;
39+     }
40+ 
3441    float slope = 1.0f;
3542
3643    // ALiBi
@@ -46,19 +53,37 @@ void main() {
4653    // Find max
4754    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
4855
49-     [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
56+     // Cache values while we compute the max, so we don't need to read them
57+     // again when we're ready to compute exp(x-max).
58+     const uint DATA_CACHE_SIZE = 16;
59+     FLOAT_TYPE data_cache[DATA_CACHE_SIZE];
60+ 
61+     [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
5062        const uint col = col0 + tid;
5163
52-         if (col >= p.KX) {
53-             break;
64+         FLOAT_TYPE a = FLOAT_TYPE(0);
65+         if (col < p.KX) {
66+             a = data_a[rowx * p.KX + col];
5467        }
5568
56-         max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
69+         FLOAT_TYPE b = FLOAT_TYPE(0);
70+         if (p.KY > 0 && col < p.KX) {
71+             b = data_b[rowy * p.KX + col];
72+         }
73+ 
74+         FLOAT_TYPE v = a * p.scale + slope * b;
75+ 
76+         max_val = max(max_val, v);
77+ 
78+         if (idx < DATA_CACHE_SIZE) {
79+             data_cache[idx] = v;
80+         }
5781    }
58-     vals[tid] = max_val;
5982
83+     // reduce across the workgroup
84+     vals[tid] = max_val;
6085    barrier();
61-     [[unroll]] for (int  s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
86+     [[unroll]] for (uint  s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
6287        if (tid < s) {
6388            vals[tid] = max(vals[tid], vals[tid + s]);
6489        }
@@ -68,39 +93,80 @@ void main() {
6893    max_val = vals[0];
6994    barrier();
7095
71-     // Sum up values
72-     vals[tid] = FLOAT_TYPE(0.0f);
96+     FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
7397
74-     [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
98+     // Compute sum{exp(x - max)}
99+     [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
75100        const uint col = col0 + tid;
76101
77102        if (col >= p.KX) {
78103            break;
79104        }
80105
106+         // compute exp(a*scale+b*slope), add it to sum, and cache the new value
107+         // in data_cache if possible.
81108        const uint i = rowx * p.KX + col;
82-         const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
83-         vals[tid] += val;
84-         data_d[i] = D_TYPE(val);
109+         FLOAT_TYPE val;
110+         if (idx < DATA_CACHE_SIZE) {
111+             val = exp(data_cache[idx] - max_val);
112+         } else {
113+             val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
114+         }
115+         sum += val;
116+         if (idx < DATA_CACHE_SIZE) {
117+             data_cache[idx] = val;
118+         } else {
119+             data_d[i] = D_TYPE(val);
120+         }
85121    }
86122
123+     // reduce across the workgroup
124+     vals[tid] = sum;
87125    barrier();
88-     [[unroll]] for (int  s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
126+     [[unroll]] for (uint  s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
89127        if (tid < s) {
90128            vals[tid] += vals[tid + s];
91129        }
92130        barrier();
93131    }
132+     sum = vals[0];
94133
95-     const D_TYPE divisor = D_TYPE(vals[0]) ;
134+     FLOAT_TYPE rcpdivisor = 1.0/sum ;
96135
97-     [[unroll]] for (uint col0 = 0; col0  < p.KX ; col0 += BLOCK_SIZE) {
136+     [[unroll]] for (uint col0 = 0, idx = 0; idx  < num_iters ; col0 += BLOCK_SIZE, ++idx ) {
98137        const uint col = col0 + tid;
99138
100139        if (col >= p.KX) {
101-             break;
140+             continue;
141+         }
142+ 
143+         if (idx < DATA_CACHE_SIZE) {
144+             data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);
145+         } else {
146+             data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
102147        }
148+     }
149+ }
103150
104-         data_d[rowx*p.KX + col] /= divisor;
151+ void main() {
152+     // instantiate the soft_max function for several different
153+     // dimensions, to allow loop unrolling
154+     uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
155+     if (num_blocks > 32) {
156+         soft_max(num_blocks);
157+     } else if (num_blocks > 16) {
158+         soft_max(32);
159+     } else if (num_blocks > 8) {
160+         soft_max(16);
161+     } else if (num_blocks > 4) {
162+         soft_max(8);
163+     } else if (num_blocks == 4) {
164+         soft_max(4);
165+     } else if (num_blocks == 3) {
166+         soft_max(3);
167+     } else if (num_blocks == 2) {
168+         soft_max(2);
169+     } else if (num_blocks == 1) {
170+         soft_max(1);
105171    }
106172}
0 commit comments