
Commit 6c379e4

vulkan: use aligned loads for flash attention mask
Rewrite the stride logic for the mask tensor in the FA shader to force the stride to be aligned, allowing more efficient loads to be used.
1 parent d3bd719 commit 6c379e4
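
As a rough illustration of the idea in the commit message, here is a minimal standalone compute shader sketch. It is not part of the commit: the push-constant fields gqa_ratio and KV mirror the names used in the real shader, but the output buffer and everything else are purely illustrative, and it only shows the stride selection plus the alignment hint in isolation.

#version 450

// Hypothetical sketch, not flash_attn_cm2.comp itself: it only demonstrates how
// the mask stride is chosen and then rounded down as an alignment hint.
layout(local_size_x = 1) in;

layout(push_constant) uniform Push {
    uint gqa_ratio;
    uint KV;
} p;

layout(std430, binding = 0) buffer Out {
    uint m_stride_out;
};

void main() {
    // With grouped query attention every row reads the same mask row, so the
    // stride is zero. Spelling zero as "gqa_ratio >> 16" (assuming gqa_ratio
    // fits in 16 bits) keeps the compiler from folding the later "& ~7"
    // through the select, which would defeat the alignment hint.
    uint m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : p.KV;

    // Alignment hint: clearing the low three bits tells the compiler the
    // stride is a multiple of 8 elements, so it can emit wider aligned loads.
    m_stride &= ~7u;

    m_stride_out = m_stride;
}

For a float16 mask, a stride that is a multiple of 8 elements presumably means 16-byte-aligned rows, which is what lets the aligned variant of the shader use the more efficient loads the commit message refers to.
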

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 7 additions & 4 deletions
@@ -201,6 +201,11 @@ void main() {
     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
     uint32_t k_stride = p.nb11;
     uint32_t v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
@@ -209,6 +214,7 @@ void main() {
         k_stride &= ~7;
         v_stride &= ~7;
 #endif
+        m_stride &= ~7;
     }
     tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
         if (p.mask != 0) {
             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-            // When using grouped query attention, all rows use the same mask.
-            if (p.gqa_ratio > 1) {
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
-            }
+            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 