Commit c99d2d5

Flash Attention Texture Compute Shader for Vulkan Backend Delegate (#12982)

Summary: Built a flash attention compute shader for the Vulkan backend delegate. The current implementation is not fully optimized, but it is functional. This shader should speed up the SDPA step in the attention block of transformer inference, since the previous implementation used many I/O operations. The implementation adds proper multi-query attention support for models like LLaMA, uses tiled block processing to reduce memory usage, and replaces several separate operations (matmul, softmax, masking) with a single fused compute shader.

Reviewed By: SS-JIA

Differential Revision: D78836150

cc @SS-JIA @manuelcandales @cbilgin

1 parent 0c1acb3 · commit c99d2d5
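
The shader implements the standard FlashAttention online-softmax recurrence: each pass over a K/V column block updates a running row maximum m, a running denominator l, and the running output O. As a sketch in the code's own notation (STEP 11 and STEP 12 in the texture shader below, with m̃ij and l̃ij the current block's row max and row sum, and P̃ij = exp(Sij − m̃ij)):

$$m_i^{\text{new}} = \max(m_i,\ \tilde{m}_{ij}), \qquad l_i^{\text{new}} = e^{m_i - m_i^{\text{new}}}\, l_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}}\, \tilde{l}_{ij}$$

$$O_i \leftarrow \frac{e^{m_i - m_i^{\text{new}}}\, l_i\, O_i \;+\; e^{\tilde{m}_{ij} - m_i^{\text{new}}}\, \tilde{P}_{ij}\, V_j}{l_i^{\text{new}}}$$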

File tree

7 files changed: +444 −181

backends/vulkan/runtime/graph/ops/glsl/flash_attention.yaml

Lines changed: 0 additions & 10 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/glsl/flash_attention.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl

Lines changed: 1 addition & 0 deletions

@@ -146,6 +146,7 @@ void main() {
         }
         score *= scale;
 
+
         // Apply causal masking: mask if global_col > global_row + input_pos
         if (global_col > global_row + input_pos) {
           score = T(-1.0 / 0.0); // Set to negative infinity

backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

flash_attention_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: flash_attention_buffer
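
These .yaml files drive the Vulkan backend's shader codegen: generate_variant_forall emits one variant per listed DTYPE (only float here), and shader_variants names the generated shader. The buffer variant above pairs with the renamed flash_attention_buffer.glsl; the texture3d variant at the end of this diff pairs with the new texture shader.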

backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl

Lines changed: 332 additions & 0 deletions

@@ -0,0 +1,332 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

// Flash Attention inputs: Query, Key, Value tensors using texture storage
${layout_declare_tensor(B, "rw", "t_O", DTYPE, "texture3d")}
${layout_declare_tensor(B, "rw", "t_l", "float", "texture3d")}
${layout_declare_tensor(B, "rw", "t_m", "float", "texture3d")}
${layout_declare_tensor(B, "r", "t_Q", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_K", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_V", DTYPE, "texture3d")}

${layout_declare_ubo(B, "ivec4", "Q_sizes")} // logical shape [B, H, N, D]; components are [D, H, N, B]
${layout_declare_ubo(B, "ivec4", "K_sizes")}
${layout_declare_ubo(B, "ivec4", "V_sizes")}
${layout_declare_ubo(B, "ivec4", "O_sizes")}

${layout_declare_ubo(B, "ivec3", "l_sizes")} // logical shape [B, H, N]; components are [N, H, B]
${layout_declare_ubo(B, "ivec3", "m_sizes")} // logical shape [B, H, N]; components are [N, H, B]

${layout_declare_ubo(B, "float", "scale")}
${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block)
${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block)
${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking
${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads
${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads

// Axis mapping setup for proper texture indexing
${layout_declare_spec_const(C, "int", "Q_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 Q_axis_map = unhash_axis_map(Q_layout);
const lowp int Q_packed_dim = unhash_packed_dim(Q_layout);

${layout_declare_spec_const(C, "int", "K_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 K_axis_map = unhash_axis_map(K_layout);
const lowp int K_packed_dim = unhash_packed_dim(K_layout);

${layout_declare_spec_const(C, "int", "V_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 V_axis_map = unhash_axis_map(V_layout);
const lowp int V_packed_dim = unhash_packed_dim(V_layout);

${layout_declare_spec_const(C, "int", "O_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 O_axis_map = unhash_axis_map(O_layout);
const lowp int O_packed_dim = unhash_packed_dim(O_layout);

${layout_declare_spec_const(C, "int", "l_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 l_axis_map = unhash_axis_map(l_layout);
const lowp int l_packed_dim = unhash_packed_dim(l_layout);

${layout_declare_spec_const(C, "int", "m_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 m_axis_map = unhash_axis_map(m_layout);
const lowp int m_packed_dim = unhash_packed_dim(m_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Maximum block sizes to prevent array overflow
#define MAX_BR 64
#define MAX_BC 128

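// t_l and t_m hold the FlashAttention running softmax statistics, one value
// per (batch, head, row): t_l accumulates the softmax denominator (rowsum of
// exponentials) and t_m tracks the running row maximum used for numerically
// stable exponentiation.
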
// Texture access helper functions using proper axis mapping
// Q_sizes, K_sizes, V_sizes, O_sizes components are [D, H, N, B] (UBO layout)
// l_sizes, m_sizes components are [N, H, B] (UBO layout)
T load_tensor_Q(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, Q_sizes, Q_axis_map, Q_packed_dim);
  int component = tidx[Q_packed_dim] % 4;
  vec4 texel = texelFetch(t_Q, pos, 0);
  return T(texel[component]);
}

T load_tensor_K(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, K_sizes, K_axis_map, K_packed_dim);
  int component = tidx[K_packed_dim] % 4;
  vec4 texel = texelFetch(t_K, pos, 0);
  return T(texel[component]);
}

T load_tensor_V(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, V_sizes, V_axis_map, V_packed_dim);
  int component = tidx[V_packed_dim] % 4;
  vec4 texel = texelFetch(t_V, pos, 0);
  return T(texel[component]);
}

T load_tensor_O(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim);
  int component = tidx[O_packed_dim] % 4;
  vec4 texel = imageLoad(t_O, pos);
  return T(texel[component]);
}

void store_tensor_O(int batch, int seq_pos, int head, int dim, T value) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim);
  int component = tidx[O_packed_dim] % 4;
  vec4 texel = imageLoad(t_O, pos);
  texel[component] = float(value);
  imageStore(t_O, pos, texel);
}

float load_tensor_l(int batch, int head, int seq_pos) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim);
  int component = tidx[l_packed_dim] % 4;
  vec4 texel = imageLoad(t_l, pos);
  return texel[component];
}

void store_tensor_l(int batch, int head, int seq_pos, float value) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim);
  int component = tidx[l_packed_dim] % 4;
  vec4 texel = imageLoad(t_l, pos);
  texel[component] = value;
  imageStore(t_l, pos, texel);
}

float load_tensor_m(int batch, int head, int seq_pos) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim);
  int component = tidx[m_packed_dim] % 4;
  vec4 texel = imageLoad(t_m, pos);
  return texel[component];
}

void store_tensor_m(int batch, int head, int seq_pos, float value) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim);
  int component = tidx[m_packed_dim] % 4;
  vec4 texel = imageLoad(t_m, pos);
  texel[component] = value;
  imageStore(t_m, pos, texel);
}

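// The helpers above map a logical (dim, head, seq, batch) index to a texel
// position via the tensor's axis map, then pick one of the four components
// packed along the packed dimension (tidx[packed_dim] % 4). Note that the
// store helpers do a read-modify-write of a whole texel and are not atomic,
// so invocations should not concurrently write different components of the
// same texel.
//
// Host-side dispatch (a sketch inferred from the thread decode in main(),
// not part of this file): one invocation per row block, i.e. a global x
// extent of at least batch_size * num_heads * ceil(seq_len / block_size_r).
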
void main() {
  // Each thread processes one row block - same as buffer version
  const int thread_id = int(gl_GlobalInvocationID.x);

  // Tensor dimensions: Q_sizes = [D, H, N, B]
  const int head_dim = Q_sizes.x;      // D (head dim)
  const int num_heads_val = Q_sizes.y; // H (num heads)
  const int seq_len = Q_sizes.z;       // N (sequence length)
  const int batch_size = Q_sizes.w;    // B (batch)

  // Block sizes
  const int Br = block_size_r;
  const int Bc = block_size_c;

  const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks
  const int total_row_blocks = batch_size * num_heads_val * Tr;

  if (thread_id >= total_row_blocks) {
    return;
  }

  // Decode thread_id to (batch, head, row_block)
  const int batch = thread_id / (num_heads_val * Tr);
  const int remaining = thread_id % (num_heads_val * Tr);
  const int head = remaining / Tr;
  const int row_block = remaining % Tr;

  // Calculate row range for this block
  const int row_start = row_block * Br;
  const int row_end = min(row_start + Br, seq_len);
  const int actual_Br = row_end - row_start;

  // STEP 1: Initialize only this thread's row block
  // Each thread initializes its own rows to avoid cross-workgroup synchronization issues
  for (int r = 0; r < actual_Br; r++) {
    const int seq_pos = row_start + r;

    // Initialize l and m textures for this row block's positions.
    // Index order (seq_pos, head, batch) matches the load/store helpers above.
    ivec4 l_tidx = ivec4(seq_pos, head, batch, 0);
    ivec3 l_pos = tidx_to_pos(l_tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim);
    vec4 l_texel = vec4(0.0);
    imageStore(t_l, l_pos, l_texel);

    ivec4 m_tidx = ivec4(seq_pos, head, batch, 0);
    ivec3 m_pos = tidx_to_pos(m_tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim);
    vec4 m_texel = vec4(-1e10);
    imageStore(t_m, m_pos, m_texel);

    // Initialize output tensor for this row block
    for (int dim = 0; dim < head_dim; dim++) {
      store_tensor_O(batch, seq_pos, head, dim, T(0.0));
    }
  }

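  // Note: the STEP numbers below appear to follow Algorithm 1 of the
  // FlashAttention paper (Dao et al., 2022); hence the jump from STEP 1
  // to STEP 5.
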
  // STEP 5: Outer loop over column blocks (For K, V tensors)
  const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks
  for (int j = 0; j < Tc; j++) {
    const int col_start = j * Bc;
    const int col_end = min(col_start + Bc, seq_len);
    const int actual_Bc = col_end - col_start;

    // Load current statistics for all rows in this block
    float m_i[MAX_BR];
    float l_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      m_i[r] = load_tensor_m(batch, head, seq_pos);
      l_i[r] = load_tensor_l(batch, head, seq_pos);
    }

    // STEP 9: Compute Sij = Qi * Kj^T
    T S_block[MAX_BR][MAX_BC];
    float m_tilde_ij[MAX_BR]; // Row maxes
    float l_tilde_ij[MAX_BR]; // Row sums

    // Initialize row statistics
    for (int r = 0; r < actual_Br; r++) {
      m_tilde_ij[r] = -1.0 / 0.0; // -infinity
      l_tilde_ij[r] = 0.0;
    }

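    // Grouped/multi-query attention: consecutive query heads share one KV
    // head. With illustrative sizes num_heads = 32 and num_kv_heads = 8,
    // kv_head = (head * 8) / 32, so query heads 0-3 all read KV head 0.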
    // Compute attention scores Sij = Qi @ Kj^T
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      for (int c = 0; c < actual_Bc; c++) {
        const int global_col = col_start + c;

        // For multi-query attention: map query head to KV head
        const int kv_head = (head * num_kv_heads) / num_heads_val;

        // Dot product: Q[seq_pos, :] · K[col_pos, :]
        T score = T(0.0);
        for (int dim = 0; dim < head_dim; dim++) {
          T q_val = load_tensor_Q(batch, global_row, head, dim);
          T k_val = load_tensor_K(batch, global_col, kv_head, dim);
          score += q_val * k_val;
        }
        score *= scale;

        // Apply causal masking: mask if global_col > global_row + input_pos
        bool masked = (global_col > global_row + input_pos);
        if (masked) {
          score = T(-1.0 / 0.0); // Set to negative infinity
        }

        S_block[r][c] = score;

        // Track row maximum (after masking)
        m_tilde_ij[r] = max(m_tilde_ij[r], float(score));
      }
    }

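    // STEP 10 subtracts the per-row block maximum before exponentiating, so
    // every exp() argument is <= 0; this is what keeps the softmax
    // numerically stable.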
    // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij)
    for (int r = 0; r < actual_Br; r++) {
      // Handle the case where all scores are -inf (fully masked row)
      if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) {
        // All scores are -inf, so all probabilities are 0
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = T(0.0);
        }
        l_tilde_ij[r] = 0.0;
      } else {
        // Normal case: compute softmax
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r]));
          l_tilde_ij[r] += float(S_block[r][c]);
        }
      }
    }

    // STEP 11: Softmax update
    float m_new_i[MAX_BR];
    float l_new_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      m_new_i[r] = max(m_i[r], m_tilde_ij[r]);
      l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r];
    }

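    // In STEP 12 below, alpha rescales the previously accumulated output to
    // the new running maximum and beta applies the matching correction to
    // this block's contribution; dividing by l_new_i keeps O_i equal to the
    // softmax-weighted average over all columns processed so far.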
    // STEP 12: Update Oi
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      float alpha = exp(m_i[r] - m_new_i[r]);
      float beta = exp(m_tilde_ij[r] - m_new_i[r]);

      // For multi-query attention: map query head to KV head
      const int kv_head = (head * num_kv_heads) / num_heads_val;

      for (int dim = 0; dim < head_dim; dim++) {
        // Compute P'ij @ Vj for this dimension
        T pv_sum = T(0.0);
        for (int c = 0; c < actual_Bc; c++) {
          const int global_col = col_start + c;
          T v_val = load_tensor_V(batch, global_col, kv_head, dim);
          pv_sum += S_block[r][c] * v_val;
        }

        // Check for division by zero before updating output
        if (l_new_i[r] <= 0.0) {
          store_tensor_O(batch, global_row, head, dim, T(0.0));
        } else {
          // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i
          T current_o = load_tensor_O(batch, global_row, head, dim);
          T new_o = (T(alpha) * T(l_i[r]) * current_o + T(beta) * pv_sum) / T(l_new_i[r]);
          store_tensor_O(batch, global_row, head, dim, new_o);
        }
      }
    }

    // STEP 13: Update li, mi
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      store_tensor_l(batch, head, seq_pos, l_new_i[r]);
      store_tensor_m(batch, head, seq_pos, m_new_i[r]);
    }
  }
}

backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

flash_attention_texture3d:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: float
  shader_variants:
    - NAME: flash_attention_texture3d
