Skip to content

Commit 0e47b49

Browse files
committed
vulkan: move common FA code to flash_attn_base.comp
1 parent 5e7d95e commit 0e47b49

File tree

4 files changed

+110
-276
lines changed

4 files changed

+110
-276
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 1 addition & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -9,60 +9,13 @@
99
#extension GL_KHR_shader_subgroup_shuffle : enable
1010

1111
#include "types.comp"
12+
#include "flash_attn_base.comp"
1213

13-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
14-
15-
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
16-
layout (constant_id = 1) const uint32_t Br = 1;
17-
layout (constant_id = 2) const uint32_t Bc = 32;
18-
layout (constant_id = 3) const uint32_t D = 32;
19-
20-
layout (constant_id = 5) const uint32_t D_split = 16;
2114
const uint32_t D_per_thread = D / D_split;
2215

2316
const uint32_t cols_per_iter = WorkGroupSize / D_split;
2417
const uint32_t cols_per_thread = Bc / cols_per_iter;
2518

26-
layout (push_constant) uniform parameter {
27-
uint32_t N;
28-
uint32_t KV;
29-
30-
uint32_t ne1;
31-
uint32_t ne2;
32-
uint32_t ne3;
33-
34-
uint32_t neq2;
35-
uint32_t neq3;
36-
uint32_t nek2;
37-
uint32_t nek3;
38-
uint32_t nev2;
39-
uint32_t nev3;
40-
uint32_t nem1;
41-
42-
uint32_t nb01;
43-
uint32_t nb02;
44-
uint32_t nb03;
45-
uint32_t nb11;
46-
uint32_t nb12;
47-
uint32_t nb13;
48-
uint32_t nb21;
49-
uint32_t nb22;
50-
uint32_t nb23;
51-
uint32_t nb31;
52-
53-
float scale;
54-
float max_bias;
55-
float logit_softcap;
56-
57-
uint32_t mask;
58-
uint32_t n_head_log2;
59-
float m0;
60-
float m1;
61-
62-
uint32_t gqa_ratio;
63-
uint32_t split_kv;
64-
uint32_t k_num;
65-
} p;
6619

6720
layout (binding = 0) readonly buffer Q {float data_q[];};
6821
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
7124
layout (binding = 2) readonly buffer V {float16_t data_v[];};
7225
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
7326
layout (binding = 3) readonly buffer M {float16_t data_m[];};
74-
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
75-
76-
#if defined(A_TYPE_PACKED16)
77-
#define BINDING_IDX_K 0
78-
#define BINDING_IDX_V 1
79-
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
80-
#endif
81-
82-
#if defined(DATA_A_Q4_0)
83-
#define BLOCK_BYTE_SIZE 18
84-
85-
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
86-
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
87-
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
88-
uint shift = (iqs & 0x10) >> 2;
89-
vui_lo >>= shift;
90-
vui_hi >>= shift;
91-
92-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
93-
}
94-
#endif
95-
96-
#if defined(DATA_A_Q8_0)
97-
#define BLOCK_BYTE_SIZE 34
98-
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
99-
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
100-
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
101-
102-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
103-
}
104-
#endif
105-
106-
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
10727

10828
// Store the output when doing grouped query attention.
10929
// Rows index by Q's dimension 2, and the first N rows are valid.
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
11434
return elem;
11535
}
11636

117-
// Store column zero. This is used to save per-row m and L values for split_k.
118-
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
119-
{
120-
if (r < N && c == 0) {
121-
uint32_t offset = iq2 + r;
122-
data_o[o_offset + offset] = D_TYPE(elem);
123-
}
124-
return elem;
125-
}
126-
127-
// Load the slope matrix, indexed by Q's dimension 2.
128-
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
129-
{
130-
const uint32_t h = iq2 + (r % p.gqa_ratio);
131-
132-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
133-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
134-
135-
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
136-
}
137-
13837
shared FLOAT_TYPE tmpsh[WorkGroupSize];
13938
shared vec4 tmpshv4[WorkGroupSize];
14039

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
2+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3+
4+
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
5+
layout (constant_id = 1) const uint32_t Br = 1;
6+
layout (constant_id = 2) const uint32_t Bc = 32;
7+
layout (constant_id = 3) const uint32_t D = 32;
8+
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
9+
layout (constant_id = 5) const uint32_t D_split = 16;
10+
11+
12+
layout (push_constant) uniform parameter {
13+
uint32_t N;
14+
uint32_t KV;
15+
16+
uint32_t ne1;
17+
uint32_t ne2;
18+
uint32_t ne3;
19+
20+
uint32_t neq2;
21+
uint32_t neq3;
22+
uint32_t nek2;
23+
uint32_t nek3;
24+
uint32_t nev2;
25+
uint32_t nev3;
26+
uint32_t nem1;
27+
28+
uint32_t nb01;
29+
uint32_t nb02;
30+
uint32_t nb03;
31+
uint32_t nb11;
32+
uint32_t nb12;
33+
uint32_t nb13;
34+
uint32_t nb21;
35+
uint32_t nb22;
36+
uint32_t nb23;
37+
uint32_t nb31;
38+
39+
float scale;
40+
float max_bias;
41+
float logit_softcap;
42+
43+
uint32_t mask;
44+
uint32_t n_head_log2;
45+
float m0;
46+
float m1;
47+
48+
uint32_t gqa_ratio;
49+
uint32_t split_kv;
50+
uint32_t k_num;
51+
} p;
52+
53+
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
54+
55+
#if defined(A_TYPE_PACKED16)
56+
#define BINDING_IDX_K 0
57+
#define BINDING_IDX_V 1
58+
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
59+
#endif
60+
61+
#if defined(DATA_A_Q4_0)
62+
#define BLOCK_BYTE_SIZE 18
63+
64+
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
65+
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
66+
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
67+
uint shift = (iqs & 0x10) >> 2;
68+
vui_lo >>= shift;
69+
vui_hi >>= shift;
70+
71+
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
72+
}
73+
#endif
74+
75+
#if defined(DATA_A_Q8_0)
76+
#define BLOCK_BYTE_SIZE 34
77+
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
78+
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
79+
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
80+
81+
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
82+
}
83+
#endif
84+
85+
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
86+
87+
88+
// Store column zero. This is used to save per-row m and L values for split_k.
89+
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
90+
{
91+
if (r < N && c == 0) {
92+
uint32_t offset = iq2 + r;
93+
data_o[o_offset + offset] = D_TYPE(elem);
94+
}
95+
return elem;
96+
}
97+
98+
// Load the slope matrix, indexed by Q's dimension 2.
99+
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
100+
{
101+
const uint32_t h = iq2 + (r % p.gqa_ratio);
102+
103+
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104+
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
105+
106+
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
107+
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 1 addition & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -11,61 +11,14 @@
1111
#extension GL_KHR_cooperative_matrix : enable
1212

1313
#include "types.comp"
14-
15-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
16-
17-
layout (constant_id = 1) const uint32_t Br = 1;
18-
layout (constant_id = 2) const uint32_t Bc = 32;
19-
layout (constant_id = 3) const uint32_t D = 32;
20-
21-
layout (constant_id = 5) const uint32_t D_split = 16;
14+
#include "flash_attn_base.comp"
2215

2316
const uint32_t D_per_thread = D / D_split;
2417
const uint32_t row_split = 4;
2518
const uint32_t rows_per_thread = Br / row_split;
2619
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
2720
const uint32_t cols_per_thread = Bc / cols_per_iter;
2821

29-
layout (push_constant) uniform parameter {
30-
uint32_t N;
31-
uint32_t KV;
32-
33-
uint32_t ne1;
34-
uint32_t ne2;
35-
uint32_t ne3;
36-
37-
uint32_t neq2;
38-
uint32_t neq3;
39-
uint32_t nek2;
40-
uint32_t nek3;
41-
uint32_t nev2;
42-
uint32_t nev3;
43-
uint32_t nem1;
44-
45-
uint32_t nb01;
46-
uint32_t nb02;
47-
uint32_t nb03;
48-
uint32_t nb11;
49-
uint32_t nb12;
50-
uint32_t nb13;
51-
uint32_t nb21;
52-
uint32_t nb22;
53-
uint32_t nb23;
54-
uint32_t nb31;
55-
56-
float scale;
57-
float max_bias;
58-
float logit_softcap;
59-
60-
uint32_t mask;
61-
uint32_t n_head_log2;
62-
float m0;
63-
float m1;
64-
65-
uint32_t gqa_ratio;
66-
uint32_t split_kv;
67-
uint32_t k_num;
68-
} p;
6922

7023
layout (binding = 0) readonly buffer Q {float data_q[];};
7124
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
7427
layout (binding = 2) readonly buffer V {float16_t data_v[];};
7528
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
7629
layout (binding = 3) readonly buffer M {float16_t data_m[];};
77-
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
78-
79-
#if defined(A_TYPE_PACKED16)
80-
#define BINDING_IDX_K 0
81-
#define BINDING_IDX_V 1
82-
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
83-
#endif
84-
85-
#if defined(DATA_A_Q4_0)
86-
#define BLOCK_BYTE_SIZE 18
87-
88-
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
89-
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
90-
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
91-
uint shift = (iqs & 0x10) >> 2;
92-
vui_lo >>= shift;
93-
vui_hi >>= shift;
94-
95-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
96-
}
97-
#endif
98-
99-
#if defined(DATA_A_Q8_0)
100-
#define BLOCK_BYTE_SIZE 34
101-
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
102-
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
103-
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
104-
105-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
106-
}
107-
#endif
108-
109-
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
11030

11131
// Store the output when doing grouped query attention.
11232
// Rows index by Q's dimension 2, and the first N rows are valid.
@@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
11737
return elem;
11838
}
11939

120-
// Store column zero. This is used to save per-row m and L values for split_k.
121-
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
122-
{
123-
if (r < N && c == 0) {
124-
uint32_t offset = iq2 + r;
125-
data_o[o_offset + offset] = D_TYPE(elem);
126-
}
127-
return elem;
128-
}
129-
130-
// Load the slope matrix, indexed by Q's dimension 2.
131-
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
132-
{
133-
const uint32_t h = iq2 + (r % p.gqa_ratio);
134-
135-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
136-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
137-
138-
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
139-
}
140-
14140
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
14241
const uint32_t MatBr = 16;
14342
const uint32_t MatBc = 16;

0 commit comments

Comments
 (0)