11#version 450
22
3- #extension GL_EXT_shader_16bit_storage : require
3+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
4+ #extension GL_EXT_control_flow_attributes : enable
45
56layout (push_constant) uniform parameter
67{
@@ -11,26 +12,32 @@ layout (push_constant) uniform parameter
1112 float m0;
1213 float m1;
1314 uint n_head_log2;
15+ uint nrows_x;
1416} p;
1517
1618#include "types.comp"
1719
18- #extension GL_EXT_control_flow_attributes : enable
19- #define BLOCK_SIZE 512
20-
21- layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
20+ layout(constant_id = 0) const uint BLOCK_SIZE = 32;
21+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2222
2323layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
2424layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
2525layout (binding = 2) buffer D {D_TYPE data_d[];};
2626
2727shared FLOAT_TYPE vals[BLOCK_SIZE];
2828
29- void main() {
29+ // num_iters is the number of BLOCK_SIZE-wide loop iterations needed to cover
30+ // all the columns. The main function tries to pass a constant here,
31+ // as if it were a template function, to allow unrolling.
32+ void soft_max(uint num_iters) {
3033 const uint tid = gl_LocalInvocationID.x;
3134 const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
3235 const uint rowy = rowx % p.KY;
3336
37+ if (rowx >= p.nrows_x) {
38+ return;
39+ }
40+
3441 float slope = 1.0f;
3542
3643 // ALiBi
@@ -46,19 +53,37 @@ void main() {
4653 // Find max
4754 FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
4855
49- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
56+ // Cache values while we compute the max, so we don't need to read them
57+ // again when we're ready to compute exp(x-max).
58+ const uint DATA_CACHE_SIZE = 16;
59+ FLOAT_TYPE data_cache[DATA_CACHE_SIZE];
60+
61+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
5062 const uint col = col0 + tid;
5163
52- if (col >= p.KX) {
53- break;
64+ FLOAT_TYPE a = FLOAT_TYPE(0);
65+ if (col < p.KX) {
66+ a = data_a[rowx * p.KX + col];
5467 }
5568
56- max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
69+ FLOAT_TYPE b = FLOAT_TYPE(0);
70+ if (p.KY > 0 && col < p.KX) {
71+ b = data_b[rowy * p.KX + col];
72+ }
73+
74+ FLOAT_TYPE v = a * p.scale + slope * b;
75+
76+ max_val = max(max_val, v);
77+
78+ if (idx < DATA_CACHE_SIZE) {
79+ data_cache[idx] = v;
80+ }
5781 }
58- vals[tid] = max_val;
5982
83+ // reduce across the workgroup
84+ vals[tid] = max_val;
6085 barrier();
61- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
86+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
6287 if (tid < s) {
6388 vals[tid] = max(vals[tid], vals[tid + s]);
6489 }
@@ -68,39 +93,80 @@ void main() {
6893 max_val = vals[0];
6994 barrier();
7095
71- // Sum up values
72- vals[tid] = FLOAT_TYPE(0.0f);
96+ FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
7397
74- [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
98+ // Compute sum{exp(x - max)}
99+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
75100 const uint col = col0 + tid;
76101
77102 if (col >= p.KX) {
78103 break;
79104 }
80105
106+ // compute exp(a*scale + b*slope - max_val), add it to sum, and cache the
107+ // new value in data_cache if possible (otherwise store it in data_d).
81108 const uint i = rowx * p.KX + col;
82- const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
83- vals[tid] += val;
84- data_d[i] = D_TYPE(val);
109+ FLOAT_TYPE val;
110+ if (idx < DATA_CACHE_SIZE) {
111+ val = exp(data_cache[idx] - max_val);
112+ } else {
113+ val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
114+ }
115+ sum += val;
116+ if (idx < DATA_CACHE_SIZE) {
117+ data_cache[idx] = val;
118+ } else {
119+ data_d[i] = D_TYPE(val);
120+ }
85121 }
86122
123+ // reduce across the workgroup
124+ vals[tid] = sum;
87125 barrier();
88- [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
126+ [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
89127 if (tid < s) {
90128 vals[tid] += vals[tid + s];
91129 }
92130 barrier();
93131 }
132+ sum = vals[0];
94133
95- const D_TYPE divisor = D_TYPE(vals[0]) ;
134+ FLOAT_TYPE rcpdivisor = 1.0/sum ;
96135
97- [[unroll]] for (uint col0 = 0; col0 < p.KX ; col0 += BLOCK_SIZE) {
136+ [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters ; col0 += BLOCK_SIZE, ++idx ) {
98137 const uint col = col0 + tid;
99138
100139 if (col >= p.KX) {
101- break;
140+ continue;
141+ }
142+
143+ if (idx < DATA_CACHE_SIZE) {
144+ data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);
145+ } else {
146+ data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
102147 }
148+ }
149+ }
103150
104- data_d[rowx*p.KX + col] /= divisor;
151+ void main() {
152+ // Entry point: call soft_max with a literal iteration count where possible, so the
153+ // [[unroll]] loops in soft_max can be unrolled; counts of 5..32 are rounded up to 8, 16 or 32.
154+ uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; // ceil(p.KX / BLOCK_SIZE)
155+ if (num_blocks > 32) {
156+ soft_max(num_blocks); // too many blocks for a literal: pass the runtime count
157+ } else if (num_blocks > 16) {
158+ soft_max(32); // 17..32 blocks; surplus iterations are skipped by soft_max's col/p.KX checks
159+ } else if (num_blocks > 8) {
160+ soft_max(16); // 9..16 blocks
161+ } else if (num_blocks > 4) {
162+ soft_max(8); // 5..8 blocks
163+ } else if (num_blocks == 4) {
164+ soft_max(4);
165+ } else if (num_blocks == 3) {
166+ soft_max(3);
167+ } else if (num_blocks == 2) {
168+ soft_max(2);
169+ } else if (num_blocks == 1) {
170+ soft_max(1);
105171 } // num_blocks == 0 (p.KX == 0): no branch matches, soft_max is not called
106172}
0 commit comments