Skip to content

Commit 32614e2

Browse files
author
Claude
committed
fixup scan
1 parent 48d1afe commit 32614e2

File tree

1 file changed: +5 additions, -25 deletions (lines changed)

1 file changed: +5 additions, -25 deletions (lines changed)

ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp

Lines changed: 5 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
#version 450
22

33
#extension GL_EXT_control_flow_attributes : require
4+
#extension GL_KHR_shader_subgroup_basic : require
5+
#extension GL_KHR_shader_subgroup_arithmetic : require
46

57
#include "types.glsl"
68

@@ -39,29 +41,6 @@ float softplus(float x) {
3941
}
4042

4143
shared float stateC[SPLIT_H * D_STATE];
42-
shared float warp_sdata[D_STATE];
43-
float warp_reduce_sum(float val) {
44-
const int tid = int(gl_LocalInvocationID.x);
45-
int lane = tid % WARP_SIZE;
46-
int warp_id = tid / WARP_SIZE;
47-
int warp_offset = warp_id * WARP_SIZE;
48-
49-
warp_sdata[warp_offset + lane] = val;
50-
barrier();
51-
52-
if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16];
53-
barrier();
54-
if (lane < 8) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 8];
55-
barrier();
56-
if (lane < 4) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 4];
57-
barrier();
58-
if (lane < 2) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 2];
59-
barrier();
60-
if (lane < 1) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 1];
61-
barrier();
62-
63-
return warp_sdata[warp_offset];
64-
}
6544

6645
void main() {
6746
const int tid = int(gl_LocalInvocationID.x);
@@ -128,9 +107,10 @@ void main() {
128107
if (idx < SPLIT_H * D_STATE) {
129108
y = stateC[idx];
130109
}
131-
y = warp_reduce_sum(y);
132110

133-
if (tid % WARP_SIZE == 0) {
111+
y = subgroupAdd(y);
112+
113+
if (gl_SubgroupInvocationID == 0) {
134114
const int k = tid / WARP_SIZE + j * (D_STATE / WARP_SIZE);
135115
d[y_base_idx + uint(i) * uint(stride_y) + uint(k)] = y;
136116
}

0 commit comments

Comments
 (0)