/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

// Flash Attention tensors, all using texture storage: the output O and the
// running softmax statistics l/m are read-write; Q, K, V are read-only.
${layout_declare_tensor(B, "rw", "t_O", DTYPE, "texture3d")}
${layout_declare_tensor(B, "rw", "t_l", "float", "texture3d")}
${layout_declare_tensor(B, "rw", "t_m", "float", "texture3d")}
${layout_declare_tensor(B, "r", "t_Q", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_K", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_V", DTYPE, "texture3d")}
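
// Each tensor above is a 3D texture whose texels pack four consecutive
// elements along one "packed" dimension; the axis-map specialization
// constants declared below describe how the logical tensor dims map onto
// texture axes.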

${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [D, H, N, B] (matches the tidx order used below)
${layout_declare_ubo(B, "ivec4", "K_sizes")}
${layout_declare_ubo(B, "ivec4", "V_sizes")}
${layout_declare_ubo(B, "ivec4", "O_sizes")}

${layout_declare_ubo(B, "ivec3", "l_sizes")} // [N, H, B]
${layout_declare_ubo(B, "ivec3", "m_sizes")} // [N, H, B]

${layout_declare_ubo(B, "float", "scale")} // softmax scale, typically 1/sqrt(D)
${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in a Q block)
${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in a K/V block)
${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking
${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads
${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads

// Axis mapping setup for proper texture indexing
${layout_declare_spec_const(C, "int", "Q_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 Q_axis_map = unhash_axis_map(Q_layout);
const lowp int Q_packed_dim = unhash_packed_dim(Q_layout);

${layout_declare_spec_const(C, "int", "K_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 K_axis_map = unhash_axis_map(K_layout);
const lowp int K_packed_dim = unhash_packed_dim(K_layout);

${layout_declare_spec_const(C, "int", "V_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 V_axis_map = unhash_axis_map(V_layout);
const lowp int V_packed_dim = unhash_packed_dim(V_layout);

${layout_declare_spec_const(C, "int", "O_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 O_axis_map = unhash_axis_map(O_layout);
const lowp int O_packed_dim = unhash_packed_dim(O_layout);

${layout_declare_spec_const(C, "int", "l_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 l_axis_map = unhash_axis_map(l_layout);
const lowp int l_packed_dim = unhash_packed_dim(l_layout);

${layout_declare_spec_const(C, "int", "m_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 m_axis_map = unhash_axis_map(m_layout);
const lowp int m_packed_dim = unhash_packed_dim(m_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Maximum block sizes, used to statically size the per-thread arrays below;
// the host must keep block_size_r <= MAX_BR and block_size_c <= MAX_BC.
#define MAX_BR 64
#define MAX_BC 128

// Texture access helpers using the axis maps declared above.
// Q_sizes, K_sizes, V_sizes, O_sizes are [D, H, N, B]; l_sizes and m_sizes
// are [N, H, B]. tidx_to_pos converts a logical 4D index to a 3D texel
// position, and tidx[packed_dim] % 4 selects the component within the texel.
T load_tensor_Q(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, Q_sizes, Q_axis_map, Q_packed_dim);
  int component = tidx[Q_packed_dim] % 4;
  vec4 texel = texelFetch(t_Q, pos, 0);
  return T(texel[component]);
}

T load_tensor_K(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, K_sizes, K_axis_map, K_packed_dim);
  int component = tidx[K_packed_dim] % 4;
  vec4 texel = texelFetch(t_K, pos, 0);
  return T(texel[component]);
}

T load_tensor_V(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, V_sizes, V_axis_map, V_packed_dim);
  int component = tidx[V_packed_dim] % 4;
  vec4 texel = texelFetch(t_V, pos, 0);
  return T(texel[component]);
}

T load_tensor_O(int batch, int seq_pos, int head, int dim) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim);
  int component = tidx[O_packed_dim] % 4;
  vec4 texel = imageLoad(t_O, pos);
  return T(texel[component]);
}

void store_tensor_O(int batch, int seq_pos, int head, int dim, T value) {
  ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order
  ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim);
  int component = tidx[O_packed_dim] % 4;
  vec4 texel = imageLoad(t_O, pos);
  texel[component] = float(value);
  imageStore(t_O, pos, texel);
}
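
// Illustrative packing arithmetic (assuming tidx_to_pos divides the packed
// dim by 4, as the % 4 component select above implies): with D = 64 packed
// along dim 0, element dim = 7 lives in texel 7 / 4 = 1 at component
// 7 % 4 = 3.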

float load_tensor_l(int batch, int head, int seq_pos) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim);
  int component = tidx[l_packed_dim] % 4;
  vec4 texel = imageLoad(t_l, pos);
  return texel[component];
}

void store_tensor_l(int batch, int head, int seq_pos, float value) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim);
  int component = tidx[l_packed_dim] % 4;
  vec4 texel = imageLoad(t_l, pos);
  texel[component] = value;
  imageStore(t_l, pos, texel);
}

float load_tensor_m(int batch, int head, int seq_pos) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim);
  int component = tidx[m_packed_dim] % 4;
  vec4 texel = imageLoad(t_m, pos);
  return texel[component];
}

void store_tensor_m(int batch, int head, int seq_pos, float value) {
  ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding)
  ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim);
  int component = tidx[m_packed_dim] % 4;
  vec4 texel = imageLoad(t_m, pos);
  texel[component] = value;
  imageStore(t_m, pos, texel);
}

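// Online softmax recurrence (FlashAttention, Dao et al. 2022; the STEP
// numbers below follow Algorithm 1 of that paper). For each new K/V block:
//   m_new = max(m_i, m~_ij)
//   l_new = exp(m_i - m_new) * l_i + exp(m~_ij - m_new) * l~_ij
//   O_i   = (exp(m_i - m_new) * l_i * O_i + exp(m~_ij - m_new) * P~_ij @ V_j) / l_new
// so the full N x N attention matrix is never materialized.
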
void main() {
  // Each thread processes one row block - same as buffer version
  const int thread_id = int(gl_GlobalInvocationID.x);

  // Tensor dimensions: Q_sizes = [D, H, N, B]
  const int head_dim = Q_sizes.x; // D (head dim)
  const int num_heads_val = Q_sizes.y; // H (num heads)
  const int seq_len = Q_sizes.z; // N (sequence length)
  const int batch_size = Q_sizes.w; // B (batch)

  // Block sizes
  const int Br = block_size_r;
  const int Bc = block_size_c;

  const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks
  const int total_row_blocks = batch_size * num_heads_val * Tr;

  if (thread_id >= total_row_blocks) {
    return;
  }

  // Decode thread_id to (batch, head, row_block)
  const int batch = thread_id / (num_heads_val * Tr);
  const int remaining = thread_id % (num_heads_val * Tr);
  const int head = remaining / Tr;
  const int row_block = remaining % Tr;

  // Calculate the row range for this block
  const int row_start = row_block * Br;
  const int row_end = min(row_start + Br, seq_len);
  const int actual_Br = row_end - row_start;

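  // Illustrative example: seq_len = 10 and Br = 4 give Tr = 3 row blocks;
  // the thread with row_block = 2 covers rows 8..9, so actual_Br = 2 on the
  // ragged final block.
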
  // STEP 1: Initialize only this thread's row block
  // Each thread initializes its own rows to avoid cross-workgroup synchronization issues
  for (int r = 0; r < actual_Br; r++) {
    const int seq_pos = row_start + r;

    // Initialize the l (row sum) and m (row max) textures for this row.
    // tidx is (seq_pos, head, batch, 0) to match the [N, H, B] order used by
    // the load/store helpers above.
    ivec4 l_tidx = ivec4(seq_pos, head, batch, 0);
    ivec3 l_pos = tidx_to_pos(l_tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim);
    imageStore(t_l, l_pos, vec4(0.0));

    ivec4 m_tidx = ivec4(seq_pos, head, batch, 0);
    ivec3 m_pos = tidx_to_pos(m_tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim);
    imageStore(t_m, m_pos, vec4(-1e10));

    // Initialize the output tensor for this row block
    for (int dim = 0; dim < head_dim; dim++) {
      store_tensor_O(batch, seq_pos, head, dim, T(0.0));
    }
  }

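  // Note: the whole-texel imageStore calls above write all four components at
  // once; this assumes texel-aligned row blocks (Br a multiple of 4) or that
  // overlapping writes of identical init constants are benign.
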
  // STEP 5: Outer loop over column blocks (for the K and V tensors)
  const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks
  for (int j = 0; j < Tc; j++) {
    const int col_start = j * Bc;
    const int col_end = min(col_start + Bc, seq_len);
    const int actual_Bc = col_end - col_start;

    // Load the current running statistics for all rows in this block
    float m_i[MAX_BR];
    float l_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      m_i[r] = load_tensor_m(batch, head, seq_pos);
      l_i[r] = load_tensor_l(batch, head, seq_pos);
    }

    // STEP 9: Compute S_ij = Q_i @ K_j^T
    T S_block[MAX_BR][MAX_BC];
    float m_tilde_ij[MAX_BR]; // Row maxes
    float l_tilde_ij[MAX_BR]; // Row sums

    // Initialize row statistics
    for (int r = 0; r < actual_Br; r++) {
      m_tilde_ij[r] = -1.0 / 0.0; // -infinity
      l_tilde_ij[r] = 0.0;
    }

    // Compute attention scores S_ij = Q_i @ K_j^T
    // Grouped-query attention: several query heads can share one KV head,
    // e.g. num_heads = 8 with num_kv_heads = 2 maps query heads 0..3 to
    // kv_head 0 and heads 4..7 to kv_head 1.
    const int kv_head = (head * num_kv_heads) / num_heads_val;

    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      for (int c = 0; c < actual_Bc; c++) {
        const int global_col = col_start + c;

        // Dot product: Q[global_row, :] . K[global_col, :]
        T score = T(0.0);
        for (int dim = 0; dim < head_dim; dim++) {
          T q_val = load_tensor_Q(batch, global_row, head, dim);
          T k_val = load_tensor_K(batch, global_col, kv_head, dim);
          score += q_val * k_val;
        }
        // Cast so the multiply stays in T (needed when T is a 16-bit float)
        score *= T(scale);

        // Apply causal masking: mask if global_col > global_row + input_pos
        if (global_col > global_row + input_pos) {
          score = T(-1.0 / 0.0); // Set to negative infinity
        }

        S_block[r][c] = score;

        // Track the row maximum (after masking)
        m_tilde_ij[r] = max(m_tilde_ij[r], float(score));
      }
    }

    // STEP 10: Compute P~_ij = exp(S_ij - m~_ij) and l~_ij = rowsum(P~_ij)
    for (int r = 0; r < actual_Br; r++) {
      // Handle the case where all scores are -inf (fully masked row)
      if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) {
        // All scores are -inf, so all probabilities are 0
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = T(0.0);
        }
        l_tilde_ij[r] = 0.0;
      } else {
        // Normal case: subtracting the row max keeps every exponent <= 0,
        // so exp() cannot overflow
        for (int c = 0; c < actual_Bc; c++) {
          S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r]));
          l_tilde_ij[r] += float(S_block[r][c]);
        }
      }
    }

    // STEP 11: Softmax statistics update:
    //   m_new = max(m_i, m~_ij)
    //   l_new = exp(m_i - m_new) * l_i + exp(m~_ij - m_new) * l~_ij
    float m_new_i[MAX_BR];
    float l_new_i[MAX_BR];
    for (int r = 0; r < actual_Br; r++) {
      m_new_i[r] = max(m_i[r], m_tilde_ij[r]);
      l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r];
    }

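    // Why alpha * l_i * O_i in the update below: the stored O_i is already
    // normalized by the previous l_i, so multiplying by l_i recovers the
    // unnormalized accumulator before rescaling by alpha and re-dividing by
    // l_new (the FlashAttention-1 style update).
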
    // STEP 12: Update O_i
    for (int r = 0; r < actual_Br; r++) {
      const int global_row = row_start + r;
      float alpha = exp(m_i[r] - m_new_i[r]); // rescales the old accumulator
      float beta = exp(m_tilde_ij[r] - m_new_i[r]); // rescales this block's contribution

      for (int dim = 0; dim < head_dim; dim++) {
        // Compute P~_ij @ V_j for this dimension (kv_head was computed above)
        T pv_sum = T(0.0);
        for (int c = 0; c < actual_Bc; c++) {
          const int global_col = col_start + c;
          T v_val = load_tensor_V(batch, global_col, kv_head, dim);
          pv_sum += S_block[r][c] * v_val;
        }

        // Guard against division by zero before updating the output
        if (l_new_i[r] <= 0.0) {
          store_tensor_O(batch, global_row, head, dim, T(0.0));
        } else {
          // O_i = (alpha * l_i * O_i + beta * P~_ij @ V_j) / l_new
          T current_o = load_tensor_O(batch, global_row, head, dim);
          T new_o = (T(alpha) * T(l_i[r]) * current_o + T(beta) * pv_sum) / T(l_new_i[r]);
          store_tensor_O(batch, global_row, head, dim, new_o);
        }
      }
    }

    // STEP 13: Update l_i, m_i
    for (int r = 0; r < actual_Br; r++) {
      const int seq_pos = row_start + r;
      store_tensor_l(batch, head, seq_pos, l_new_i[r]);
      store_tensor_m(batch, head, seq_pos, m_new_i[r]);
    }
  }
}