
Commit c9ca9d5

vulkan: move common FA index/stride setup code to flash_attn_base.comp
1 parent 0e47b49
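
In short, the index/stride setup that each flash-attention variant used to duplicate now lives in flash_attn_base.comp: the values are declared at module scope, computed once in a new init_indices() helper, and each variant's main() simply calls the helper. The sketch below is abridged from the diff, not the exact committed code; the push-constant block p, the Br/Bc tile constants, and CEIL_DIV are defined elsewhere in these shaders and are assumed here.

// flash_attn_base.comp -- shared by all variants. GLSL has no multi-value
// return, so the helper writes module-scope variables that main() reads.
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j;

void init_indices()
{
    N  = p.N;                          // number of query rows
    KV = p.KV;                         // number of key/value rows

    i  = gl_WorkGroupID.x;
    split_k_index = 0;
    if (p.k_num > 1) {                 // split-K: one KV slice per workgroup
        i = 0;
        split_k_index = gl_WorkGroupID.x;
    }

    Tr      = CEIL_DIV(N, Br);
    start_j = split_k_index * p.split_kv / Bc;
    end_j   = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
    // ... plus iq2/iq3, the k/v broadcast indices, and the q/k/v/m strides,
    // exactly as in the full diff below
}

// flash_attn.comp / flash_attn_cm1.comp / flash_attn_cm2.comp
void main()
{
    init_indices();                    // replaces ~45 duplicated setup lines
    // ... per-variant tile loads and attention math follow
}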

4 files changed (+60, -141 lines)

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 2 additions & 48 deletions
@@ -45,58 +45,12 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint32_t tid = gl_LocalInvocationIndex;
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
+    init_indices();
 
+    const uint32_t tid = gl_LocalInvocationIndex;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
     const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
 
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 55 additions & 0 deletions
@@ -105,3 +105,58 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
+
+uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
+    iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+    q_stride, k_stride, v_stride, m_stride;
+
+void init_indices()
+{
+    N = p.N;
+    KV = p.KV;
+
+    i = gl_WorkGroupID.x;
+    split_k_index = 0;
+
+    if (p.k_num > 1) {
+        i = 0;
+        split_k_index = gl_WorkGroupID.x;
+    }
+
+    Tr = CEIL_DIV(N, Br);
+
+    start_j = split_k_index * p.split_kv / Bc;
+    end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
+
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+    iq3 = gl_WorkGroupID.z;
+
+    // broadcast factors
+    rk2 = p.neq2/p.nek2;
+    rk3 = p.neq3/p.nek3;
+
+    rv2 = p.neq2/p.nev2;
+    rv3 = p.neq3/p.nev3;
+
+    // k indices
+    ik3 = iq3 / rk3;
+    ik2 = iq2 / rk2;
+
+    // v indices
+    iv3 = iq3 / rv3;
+    iv2 = iq2 / rv2;
+
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+    k_stride = p.nb11;
+    v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 2 additions & 47 deletions
@@ -61,9 +61,9 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
+    init_indices();
+
     const uint32_t tid = gl_LocalInvocationIndex;
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
 
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
@@ -72,51 +72,6 @@ void main() {
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 46 deletions
@@ -73,41 +73,7 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
-
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
+    init_indices();
 
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
@@ -124,17 +90,6 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
