Skip to content

Commit 6ea605d

Browse files
committed
add [[unroll]] and remove unnecessary conditions
1 parent 64c16c4 commit 6ea605d

File tree

1 file changed

+6 -4 lines changed

1 file changed

+6 -4 lines changed

ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
#version 450
22

3+
#extension GL_EXT_control_flow_attributes : require
4+
35
#define BLOCK_SIZE 64
46
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
57

@@ -29,12 +31,12 @@ void main() {
2931
const uint state_size = C * head_size;
3032
const uint n_seq_tokens = T / B;
3133

32-
if (tid >= head_size || batch_id >= B || head_id >= H) {
34+
if (batch_id >= B || head_id >= H) {
3335
return;
3436
}
3537

3638
A_TYPE state[BLOCK_SIZE];
37-
for (uint i = 0; i < head_size; i++) {
39+
[[unroll]] for (uint i = 0; i < head_size; i++) {
3840
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
3941
+ i * head_size + tid];
4042
}
@@ -56,7 +58,7 @@ void main() {
5658
const A_TYPE v_val = v[t];
5759
A_TYPE y = 0.0;
5860

59-
for (uint j = 0; j < head_size; j += 4) {
61+
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
6062
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
6163
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
6264
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
@@ -78,7 +80,7 @@ void main() {
7880
dst[t] = y;
7981
}
8082

81-
for (uint i = 0; i < head_size; i++) {
83+
[[unroll]] for (uint i = 0; i < head_size; i++) {
8284
dst[T * C + batch_id * state_size + head_id * head_size * head_size
8385
+ i * head_size + tid] = state[i];
8486
}

0 commit comments

Comments
 (0)