Commit 668306f

jeffbolznv authored and ggerganov committed
vulkan: fix coopmat2 flash attention for non-contiguous inputs (llama/11281)
Add code similar to mul_mm_cm2 to force alignment of strides, to avoid a performance regression. Add noncontiguous FA tests in test-backend-ops. Fixes #11268.
1 parent fdc21fc commit 668306f
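
For context, here is a minimal standalone sketch of the stride-alignment check this change introduces. The helper below is illustrative only and not part of the patch; the names nbq1/nbk1/nbv1 and the division by the element size mirror the diff, but the element sizes are passed in directly so the snippet compiles on its own.

    #include <cstdint>
    #include <cstddef>

    // Convert per-row byte strides to element strides and require each to be a
    // multiple of 8 elements before the "aligned" coopmat2 pipeline may be used.
    // The real code additionally requires (KV % pipelines[1]->align) == 0.
    static bool strides_allow_aligned_variant(size_t nbq1, size_t nbk1, size_t nbv1,
                                              size_t q_size, size_t k_size, size_t v_size) {
        const uint32_t q_stride = (uint32_t)(nbq1 / q_size);
        const uint32_t k_stride = (uint32_t)(nbk1 / k_size);
        const uint32_t v_stride = (uint32_t)(nbv1 / v_size);
        return (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
    }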

File tree
  ggml/src/ggml-vulkan/ggml-vulkan.cpp
  ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

2 files changed: +57, -6 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 37 additions & 6 deletions
@@ -386,10 +386,13 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;

+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;
@@ -4809,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
     assert(pipelines);

-    bool aligned = (KV % pipelines[1]->align) == 0;
+    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
+    const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
+    const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+
+    bool aligned = (KV % pipelines[1]->align) == 0 &&
+        // the "aligned" shader variant will forcibly align strides, for performance
+        (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);

@@ -4845,15 +4855,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

     if (ctx->device->uma) {
         ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
-        ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
         Q_uma = d_Q != nullptr;
         K_uma = d_K != nullptr;
         V_uma = d_V != nullptr;
         D_uma = d_D != nullptr;
         if (mask) {
-            ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
+            ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
             M_uma = d_M != nullptr;
         }
     }
@@ -4891,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }

-    const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
+    const vk_flash_attn_push_constants pc = { N, KV,
+                                              (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
+                                              (uint32_t)neq2, (uint32_t)neq3,
+                                              (uint32_t)nek2, (uint32_t)nek3,
+                                              (uint32_t)nev2, (uint32_t)nev3,
+                                              nem1,
+                                              q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
+                                              k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
+                                              v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
+                                              nbm1,
+                                              scale, max_bias, logit_softcap,
+                                              mask != nullptr, n_head_log2, m0, m1 };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -8668,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
     ggml_tensor * src2 = tensor->src[2];
+    ggml_tensor * src3 = tensor->src[3];

     void * tensor_data = tensor->data;

@@ -8730,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     if (src2 != nullptr) {
         std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
     }
+    if (src3 != nullptr) {
+        std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+    }
     std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
     std::cerr << std::endl << "Result:" << std::endl;
     ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
@@ -8774,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     if (src2 != nullptr) {
         std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
     }
+    if (src3 != nullptr) {
+        std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+    }
     std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
     std::cerr << std::endl << "Result:" << std::endl;
     ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
@@ -8796,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     if (src2 != nullptr) {
         std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
     }
+    if (src3 != nullptr) {
+        std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
+    }
     std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
     std::cerr << std::endl << "Result:" << std::endl;
     ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
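
The new nb01/nb11/nb21 push constants carry the per-row stride of Q, K and V in elements, which the shader previously did not receive and which differs from the packed row size when an input is a non-contiguous view. A rough illustration of that relationship, assuming ggml's usual ne/nb convention (element counts vs. byte strides); the numbers below are hypothetical:

    #include <cstdint>
    #include <cstddef>
    #include <cstdio>

    int main() {
        // For a contiguous f32 tensor, nb[1] == ne[0] * sizeof(float).
        const int64_t ne0        = 128;                    // hypothetical head size
        const size_t  type_size  = sizeof(float);
        const size_t  nb1_contig = ne0 * type_size;        // packed rows
        const size_t  nb1_view   = 2 * ne0 * type_size;    // e.g. a view over every other row

        // This element stride is what the patch now forwards to the shader:
        const uint32_t row_stride_elems = (uint32_t)(nb1_view / type_size);

        std::printf("contiguous nb1 = %zu bytes, view row stride = %u elements\n",
                    nb1_contig, row_stride_elems);
        return 0;
    }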

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 20 additions & 0 deletions
@@ -42,10 +42,13 @@ layout (push_constant) uniform parameter {
     uint32_t nev3;
     uint32_t nem1;

+    uint32_t nb01;
     uint32_t nb02;
     uint32_t nb03;
+    uint32_t nb11;
     uint32_t nb12;
     uint32_t nb13;
+    uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
     uint32_t nb31;
@@ -146,6 +149,23 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

+    // nb?1 are already divided by the type size and are in units of elements
+    uint32_t q_stride = p.nb01;
+    uint32_t k_stride = p.nb11;
+    uint32_t v_stride = p.nb21;
+    // hint to the compiler that strides are aligned for the aligned variant of the shader
+    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
+    {
+        q_stride &= ~7;
+#if !defined(BLOCK_SIZE)
+        k_stride &= ~7;
+        v_stride &= ~7;
+#endif
+    }
+    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
+    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
+    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
+
     coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
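
The &= ~7 applied in the aligned shader variant is a no-op whenever the host-side check has already guaranteed an 8-element multiple; clearing the low bits simply makes that guarantee visible to the compiler, per the comment above. A tiny C++ analogue of the arithmetic (illustrative only, not part of the shader):

    #include <cassert>
    #include <cstdint>

    int main() {
        // Strides that are multiples of 8 are unchanged by the mask ...
        for (uint32_t stride = 0; stride < 1024; stride += 8) {
            assert((stride & ~7u) == stride);
        }
        // ... while any other stride would be altered, which is why such tensors
        // must take the unaligned (clamping) pipeline instead.
        assert((42u & ~7u) == 40u);
        return 0;
    }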
