@@ -40,26 +40,6 @@ float softplus(float x) {
4040
4141shared float stateC[SPLIT_H * D_STATE];
4242shared float warp_sdata[D_STATE];
43- float warp_reduce_sum(float val) { // Sums val across the WARP_SIZE invocations that share a warp_id, staging through shared warp_sdata, and returns that sum to each of them.
44- const int tid = int(gl_LocalInvocationID.x);
45- int lane = tid % WARP_SIZE; // position of this invocation inside its WARP_SIZE-wide group
46- int warp_id = tid / WARP_SIZE; // index of the group this invocation belongs to
47- int warp_offset = warp_id * WARP_SIZE; // first warp_sdata slot used by this group
48-
49- warp_sdata[warp_offset + lane] = val; // publish this invocation's value; NOTE(review): warp_sdata has D_STATE entries, so this assumes tid < D_STATE - confirm
50- barrier(); // every value must be stored before any partner slot is read below
51-
52- int offset = WARP_SIZE / 2;
53- while (offset > 0) { // pairwise reduction: the active span halves on each pass; NOTE(review): covers every lane only if WARP_SIZE is a power of two - confirm
54- if (lane < offset) {
55- warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + offset]; // fold the upper half of the span into the lower half
56- }
57- barrier(); // placed outside the if, so invocations that skipped the add still reach it
58- offset >>= 1;
59- }
60-
61- return warp_sdata[warp_offset]; // the group's first slot holds the accumulated result
62- }
6343
6444void main() {
6545 const int tid = int(gl_LocalInvocationID.x);
@@ -126,7 +106,22 @@ void main() {
126106 if (idx < SPLIT_H * D_STATE) {
127107 y = stateC[idx];
128108 }
129- y = warp_reduce_sum(y);
109+
110+ int lane = tid % WARP_SIZE;
111+ int warp_id = tid / WARP_SIZE;
112+ int warp_offset = warp_id * WARP_SIZE;
113+
114+ warp_sdata[warp_offset + lane] = y;
115+ barrier();
116+
117+ [[unroll]] for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
118+ if (lane < offset) {
119+ warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + offset];
120+ }
121+ barrier();
122+ }
123+
124+ y = warp_sdata[warp_offset];
130125
131126 if (tid % WARP_SIZE == 0) {
132127 const int k = tid / WARP_SIZE + j * (D_STATE / WARP_SIZE);
0 commit comments