11#version 450
22
3+ #extension GL_EXT_control_flow_attributes : require
4+ 
35#define BLOCK_SIZE 64
46layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
57
@@ -29,12 +31,12 @@ void main() {
2931    const uint state_size = C * head_size;
3032    const uint n_seq_tokens = T / B;
3133
32-     if (tid >= head_size ||  batch_id >= B || head_id >= H) {
34+     if (batch_id >= B || head_id >= H) {
3335        return;
3436    }
3537
3638    A_TYPE state[BLOCK_SIZE];
37-     for (uint i = 0; i < head_size; i++) {
39+     [[unroll]]  for (uint i = 0; i < head_size; i++) {
3840        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
3941                          + i * head_size + tid];
4042    }
@@ -56,7 +58,7 @@ void main() {
5658        const A_TYPE v_val = v[t];
5759        A_TYPE y = 0.0;
5860
59-         for (uint j = 0; j < head_size; j += 4) {
61+         [[unroll]]  for (uint j = 0; j < head_size; j += 4) {
6062            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
6163            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
6264            vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
@@ -78,7 +80,7 @@ void main() {
7880        dst[t] = y;
7981    }
8082
81-     for (uint i = 0; i < head_size; i++) {
83+     [[unroll]]  for (uint i = 0; i < head_size; i++) {
8284        dst[T * C + batch_id * state_size + head_id * head_size * head_size
8385            + i * head_size + tid] = state[i];
8486    }
0 commit comments