Skip to content

Commit 868060f

Browse files
author
Claude
committed
fix_1
1 parent 719cb2d commit 868060f

File tree

1 file changed

+16
-21
lines changed

1 file changed

+16
-21
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp

Lines changed: 16 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -40,26 +40,6 @@ float softplus(float x) {
40 40

41 41
shared float stateC[SPLIT_H * D_STATE];
42 42
shared float warp_sdata[D_STATE];
43-
float warp_reduce_sum(float val) {
44-
const int tid = int(gl_LocalInvocationID.x);
45-
int lane = tid % WARP_SIZE;
46-
int warp_id = tid / WARP_SIZE;
47-
int warp_offset = warp_id * WARP_SIZE;
48-
49-
warp_sdata[warp_offset + lane] = val;
50-
barrier();
51-
52-
int offset = WARP_SIZE / 2;
53-
while (offset > 0) {
54-
if (lane < offset) {
55-
warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + offset];
56-
}
57-
barrier();
58-
offset >>= 1;
59-
}
60-
61-
return warp_sdata[warp_offset];
62-
}
63 43

64 44
void main() {
65 45
const int tid = int(gl_LocalInvocationID.x);
@@ -126,7 +106,22 @@ void main() {
126 106
if (idx < SPLIT_H * D_STATE) {
127 107
y = stateC[idx];
128 108
}
129-
y = warp_reduce_sum(y);
109+
110+
int lane = tid % WARP_SIZE;
111+
int warp_id = tid / WARP_SIZE;
112+
int warp_offset = warp_id * WARP_SIZE;
113+
114+
warp_sdata[warp_offset + lane] = y;
115+
barrier();
116+
117+
[[unroll]] for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
118+
if (lane < offset) {
119+
warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + offset];
120+
}
121+
barrier();
122+
}
123+
124+
y = warp_sdata[warp_offset];
130 125

131 126
if (tid % WARP_SIZE == 0) {
132 127
const int k = tid / WARP_SIZE + j * (D_STATE / WARP_SIZE);

0 commit comments

Comments
 (0)