Skip to content

Commit 0b4bb4f

Browse files
committed
vulkan: Handle updated FA dim2/3 definition
Pack mask boolean and n_head_log2 into a single dword to keep the push constant block under the 128B limit.
1 parent 7b63a71 commit 0b4bb4f

File tree

5 files changed

+26
-24
lines changed

5 files changed

+26
-24
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ struct vk_flash_attn_push_constants {
634634
uint32_t nev3;
635635
uint32_t nem1;
636636
uint32_t nem2;
637+
uint32_t nem3;
637638

638639
uint32_t nb01;
639640
uint32_t nb02;
@@ -649,8 +650,7 @@ struct vk_flash_attn_push_constants {
649650
float max_bias;
650651
float logit_softcap;
651652

652-
uint32_t mask;
653-
uint32_t n_head_log2;
653+
uint32_t mask_n_head_log2;
654654
float m0;
655655
float m1;
656656

@@ -6050,6 +6050,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
60506050

60516051
const uint32_t nem1 = mask ? mask->ne[1] : 0;
60526052
const uint32_t nem2 = mask ? mask->ne[2] : 0;
6053+
const uint32_t nem3 = mask ? mask->ne[3] : 0;
60536054

60546055
const uint32_t D = neq0;
60556056
uint32_t N = neq1;
@@ -6119,7 +6120,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61196120
}
61206121

61216122
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6122-
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
6123+
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1 && nem2 == 1 && nem3 == 1) {
61236124
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
61246125
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
61256126
// and change addressing calculations to index Q's dimension 2.
@@ -6311,17 +6312,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
63116312
}
63126313
}
63136314

6315+
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6316+
63146317
const vk_flash_attn_push_constants pc = { N, KV,
63156318
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
63166319
(uint32_t)neq2, (uint32_t)neq3,
63176320
(uint32_t)nek2, (uint32_t)nek3,
63186321
(uint32_t)nev2, (uint32_t)nev3,
6319-
nem1, nem2,
6322+
nem1, nem2, nem3,
63206323
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
63216324
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
63226325
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
63236326
scale, max_bias, logit_softcap,
6324-
mask != nullptr, n_head_log2, m0, m1,
6327+
mask_n_head_log2, m0, m1,
63256328
gqa_ratio, split_kv, split_k };
63266329

63276330
ggml_vk_sync_buffers(subctx);
@@ -10265,12 +10268,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1026510268
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1026610269
return false;
1026710270
}
10268-
// TODO: support broadcast
10269-
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
10270-
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
10271-
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
10272-
return false;
10273-
}
1027410271
// It's straightforward to support different K/V dequant, but would
1027510272
// significantly increase the number of pipelines
1027610273
if (op->src[1]->type != op->src[2]->type) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ void main() {
100100
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
101101
#endif
102102
uint32_t m_offset = 0;
103-
if (p.nem2 != 1) {
104-
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
103+
if (p.nem2 != 1 || p.nem3 != 1) {
104+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
105105
}
106106

107107
[[dont_unroll]]
@@ -148,7 +148,7 @@ void main() {
148148
}
149149
}
150150

151-
if (p.mask != 0) {
151+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
152152

153153
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
154154
uint32_t c = (idx + tid) % Bc;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
2525
uint32_t nev3;
2626
uint32_t nem1;
2727
uint32_t nem2;
28+
uint32_t nem3;
2829

2930
uint32_t nb01;
3031
uint32_t nb02;
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
4041
float max_bias;
4142
float logit_softcap;
4243

43-
uint32_t mask;
44-
uint32_t n_head_log2;
44+
uint32_t mask_n_head_log2;
4545
float m0;
4646
float m1;
4747

@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
5050
uint32_t k_num;
5151
} p;
5252

53+
#define MASK_ENABLE_BIT (1<<16)
54+
#define N_LOG2_MASK 0xFFFF
55+
5356
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
5457

5558
#if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
100103
{
101104
const uint32_t h = iq2 + (r % p.gqa_ratio);
102105

103-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
106+
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
107+
108+
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
109+
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
105110

106111
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
107112
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ void main() {
124124
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
125125
#endif
126126
uint32_t m_offset = 0;
127-
if (p.nem2 != 1) {
128-
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
127+
if (p.nem2 != 1 || p.nem3 != 1) {
128+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
129129
}
130130

131131
[[dont_unroll]]
@@ -180,7 +180,7 @@ void main() {
180180
barrier();
181181
}
182182

183-
if (p.mask != 0) {
183+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
184184
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
185185
uint32_t c = (idx + tid) % Bc;
186186
uint32_t r = (idx + tid) / Bc;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ void main() {
131131
}
132132

133133
uint32_t m_offset = 0;
134-
if (p.nem2 != 1) {
135-
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
134+
if (p.nem2 != 1 || p.nem3 != 1) {
135+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
136136
}
137137

138138
[[dont_unroll]]
@@ -153,7 +153,7 @@ void main() {
153153
}
154154
}
155155

156-
if (p.mask != 0) {
156+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
157157
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
158158
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
159159
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

0 commit comments

Comments
 (0)