|
1 | 1 | #version 450 |
2 | 2 |
|
3 | 3 | #extension GL_EXT_control_flow_attributes : require |
| 4 | +#extension GL_KHR_shader_subgroup_basic : require |
| 5 | +#extension GL_KHR_shader_subgroup_arithmetic : require |
4 | 6 |
|
5 | 7 | #include "types.glsl" |
6 | 8 |
|
@@ -39,29 +41,6 @@ float softplus(float x) { |
39 | 41 | } |
40 | 42 |
|
41 | 43 | shared float stateC[SPLIT_H * D_STATE]; |
42 | | -shared float warp_sdata[D_STATE]; |
43 | | -float warp_reduce_sum(float val) { |
44 | | - const int tid = int(gl_LocalInvocationID.x); |
45 | | - int lane = tid % WARP_SIZE; |
46 | | - int warp_id = tid / WARP_SIZE; |
47 | | - int warp_offset = warp_id * WARP_SIZE; |
48 | | - |
49 | | - warp_sdata[warp_offset + lane] = val; |
50 | | - barrier(); |
51 | | - |
52 | | - if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16]; |
53 | | - barrier(); |
54 | | - if (lane < 8) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 8]; |
55 | | - barrier(); |
56 | | - if (lane < 4) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 4]; |
57 | | - barrier(); |
58 | | - if (lane < 2) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 2]; |
59 | | - barrier(); |
60 | | - if (lane < 1) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 1]; |
61 | | - barrier(); |
62 | | - |
63 | | - return warp_sdata[warp_offset]; |
64 | | -} |
65 | 44 |
|
66 | 45 | void main() { |
67 | 46 | const int tid = int(gl_LocalInvocationID.x); |
@@ -128,9 +107,10 @@ void main() { |
128 | 107 | if (idx < SPLIT_H * D_STATE) { |
129 | 108 | y = stateC[idx]; |
130 | 109 | } |
131 | | - y = warp_reduce_sum(y); |
132 | 110 |
|
133 | | - if (tid % WARP_SIZE == 0) { |
| 111 | + y = subgroupAdd(y); |
| 112 | + |
| 113 | + if (gl_SubgroupInvocationID == 0) { |
134 | 114 | const int k = tid / WARP_SIZE + j * (D_STATE / WARP_SIZE); |
135 | 115 | d[y_base_idx + uint(i) * uint(stride_y) + uint(k)] = y; |
136 | 116 | } |
|
0 commit comments