diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 68d7f90d09c..7cb22d90f60 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -971,6 +971,13 @@ jobs: ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row + # "Classic" Operator tests + PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build + # TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are + # failing due to the libstdc++.so.6 installed with conda not supporting + # GLIBCXX_3.4.30. These tests are still run in Meta internal CI. + # ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test + # Run e2e testing for selected operators. More operators will be tested via this # route in the future. python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*" diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 934e02eb7be..9f1561fb05e 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -571,7 +571,7 @@ def register_sdpa_with_kv_cache_op(): ) def register_sdpa_ops(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl deleted file mode 100644 index 8509fdf1f49..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.glsl +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} -#define T ${buffer_scalar_type(DTYPE)} -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -// Flash Attention inputs: Query, Key, Value tensors -${layout_declare_tensor(B, "rw", "t_O", DTYPE, "buffer")} -${layout_declare_tensor(B, "rw", "t_l", "float", "buffer")} -${layout_declare_tensor(B, "rw", "t_m", "float", "buffer")} -${layout_declare_tensor(B, "r", "t_Q", DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_K", DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_V", DTYPE, "buffer")} - -${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D] -${layout_declare_ubo(B, "ivec4", "K_sizes")} -${layout_declare_ubo(B, "ivec4", "V_sizes")} -${layout_declare_ubo(B, "ivec4", "O_sizes")} - -${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N] -${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N] - -${layout_declare_ubo(B, "float", "scale")} -${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block) -${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block) -${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking -${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads -${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// Maximum block sizes to prevent array overflow -#define MAX_BR 64 -#define MAX_BC 128 - -void main() { - // Each thread processes one row block - const int thread_id = int(gl_GlobalInvocationID.x); - - // Tensor dimensions: Q_sizes = [D, H, N, B] from graph.sizes_ubo() - // The UBO layout is different from the PyTorch tensor layout - const int head_dim = Q_sizes.x; // D (head dim) - const int num_heads = Q_sizes.y; // H (num heads) - const int seq_len = Q_sizes.z; // N (sequence length) - const int batch_size = Q_sizes.w; // B (batch) - - // Block sizes - const int Br = block_size_r; - const int Bc = block_size_c; - - const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks - const int total_row_blocks = batch_size * num_heads * Tr; - - if (thread_id >= total_row_blocks) { - return; - } - - // Decode thread_id to (batch, head, row_block) - const int batch = thread_id / (num_heads * Tr); - const int remaining = thread_id % (num_heads * Tr); - const int head = remaining / Tr; - const int row_block = remaining % Tr; - - // Calculate row range for this block - const int row_start = row_block * Br; - const int row_end = min(row_start + Br, seq_len); - const int actual_Br = row_end - row_start; - - // Base indices for this batch - const int q_base = batch * (seq_len * num_heads * head_dim); - const int k_base = batch * (seq_len * num_heads * head_dim); - const int v_base = batch * (seq_len * num_heads * head_dim); - const int o_base = batch * (seq_len * num_heads * head_dim); - const int lm_base = batch * (seq_len * num_heads); - - // STEP 2: Initialize O = 0, l = 0, m = -inf for this row block - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - const int lm_idx = lm_base + head * seq_len + seq_pos; - - t_l[lm_idx] = 0.0; - t_m[lm_idx] = -1.0 / 0.0; // -infinity - - for (int dim = 0; dim < head_dim; dim++) { - const int o_idx = o_base + seq_pos * (num_heads * head_dim) + head * head_dim + dim; - t_O[o_idx] = T(0.0); - } - } - - // STEP 5: Outer loop over column blocks (For K, V tensors) - const int Tc = (seq_len + Bc - 1) / Bc; // Number of 
column blocks - for (int j = 0; j < Tc; j++) { - const int col_start = j * Bc; - const int col_end = min(col_start + Bc, seq_len); - const int actual_Bc = col_end - col_start; - - // STEP 6-8 done implicitly below - - // Load current statistics for all rows in this block - float m_i[MAX_BR]; - float l_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - const int lm_idx = lm_base + head * seq_len + seq_pos; - m_i[r] = t_m[lm_idx]; - l_i[r] = t_l[lm_idx]; - } - - // STEP 9: Compute Sij = Qi * Kj^T - T S_block[MAX_BR][MAX_BC]; // Use MAX_BR and MAX_BC constants - float m_tilde_ij[MAX_BR]; // Row maxes (float to match l/m) - float l_tilde_ij[MAX_BR]; // Row sums (float to match l/m) - - // Initialize row statistics - for (int r = 0; r < actual_Br; r++) { - m_tilde_ij[r] = -1.0 / 0.0; // -infinity - l_tilde_ij[r] = 0.0; - } - - // Compute attention scores Sij = Qi @ Kj^T - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads; - - // Dot product: Q[seq_pos, :] · K[col_pos, :] - T score = T(0.0); - for (int dim = 0; dim < head_dim; dim++) { - const int q_idx = q_base + global_row * (num_heads * head_dim) + head * head_dim + dim; - const int k_idx = k_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim; - score += t_Q[q_idx] * t_K[k_idx]; - } - score *= scale; - - - // Apply causal masking: mask if global_col > global_row + input_pos - if (global_col > global_row + input_pos) { - score = T(-1.0 / 0.0); // Set to negative infinity - } - - S_block[r][c] = score; - - // Track row maximum (after masking) - m_tilde_ij[r] = max(m_tilde_ij[r], float(score)); - } - } - - // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij) - for (int r = 0; r < actual_Br; r++) { - // Handle the case where all scores are -inf (fully masked row) - if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) { - // All scores are -inf, so all probabilities are 0 - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = T(0.0); - } - l_tilde_ij[r] = 0.0; - } else { - // Normal case: compute softmax - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r])); - l_tilde_ij[r] += float(S_block[r][c]); - } - } - } - - // STEP 11: Softmax update - float m_new_i[MAX_BR]; - float l_new_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - m_new_i[r] = max(m_i[r], m_tilde_ij[r]); - - l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r]; - } - - // STEP 12: Update Oi - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - float alpha = exp(m_i[r] - m_new_i[r]); - float beta = exp(m_tilde_ij[r] - m_new_i[r]); - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads; - - for (int dim = 0; dim < head_dim; dim++) { - const int o_idx = o_base + global_row * (num_heads * head_dim) + head * head_dim + dim; - - // Compute P'ij @ Vj for this dimension - T pv_sum = T(0.0); - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - const int v_idx = v_base + global_col * (num_kv_heads * head_dim) + kv_head * head_dim + dim; - pv_sum += S_block[r][c] * t_V[v_idx]; - } - - // Check for division by zero before updating output - if (l_new_i[r] <= 0.0) { - t_O[o_idx] = T(0.0); // Set to 
zero to avoid NaN - } else { - // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i - t_O[o_idx] = (T(alpha) * T(l_i[r]) * t_O[o_idx] + T(beta) * pv_sum) / T(l_new_i[r]); - } - } - } - - // STEP 13: Update li, mi - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - const int lm_idx = lm_base + head * seq_len + seq_pos; - t_l[lm_idx] = l_new_i[r]; - t_m[lm_idx] = m_new_i[r]; - } - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl deleted file mode 100644 index 1f72a583410..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.glsl +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} -#define T ${buffer_scalar_type(DTYPE)} -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -// Flash Attention inputs: Query, Key, Value tensors using texture storage -${layout_declare_tensor(B, "rw", "t_O", DTYPE, "texture3d")} -${layout_declare_tensor(B, "rw", "t_l", "float", "texture3d")} -${layout_declare_tensor(B, "rw", "t_m", "float", "texture3d")} -${layout_declare_tensor(B, "r", "t_Q", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_K", DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_V", DTYPE, "texture3d")} - -${layout_declare_ubo(B, "ivec4", "Q_sizes")} // [B, H, N, D] -${layout_declare_ubo(B, "ivec4", "K_sizes")} -${layout_declare_ubo(B, "ivec4", "V_sizes")} -${layout_declare_ubo(B, "ivec4", "O_sizes")} - -${layout_declare_ubo(B, "ivec3", "l_sizes")} // [B, H, N] -${layout_declare_ubo(B, "ivec3", "m_sizes")} // [B, H, N] - -${layout_declare_ubo(B, "float", "scale")} -${layout_declare_ubo(B, "int", "block_size_r")} // Br (num rows in Q block) -${layout_declare_ubo(B, "int", "block_size_c")} // Bc (num cols in K/V block) -${layout_declare_ubo(B, "int", "input_pos")} // Starting position for causal masking -${layout_declare_ubo(B, "int", "num_heads")} // Number of query heads -${layout_declare_ubo(B, "int", "num_kv_heads")} // Number of key/value heads - -// Axis mapping setup for proper texture indexing -${layout_declare_spec_const(C, "int", "Q_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 Q_axis_map = unhash_axis_map(Q_layout); -const lowp int Q_packed_dim = unhash_packed_dim(Q_layout); - -${layout_declare_spec_const(C, "int", "K_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 K_axis_map = unhash_axis_map(K_layout); -const lowp int K_packed_dim = unhash_packed_dim(K_layout); - -${layout_declare_spec_const(C, "int", "V_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 V_axis_map = unhash_axis_map(V_layout); -const lowp int V_packed_dim = unhash_packed_dim(V_layout); - -${layout_declare_spec_const(C, "int", "O_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 O_axis_map = unhash_axis_map(O_layout); -const lowp int O_packed_dim = unhash_packed_dim(O_layout); - -${layout_declare_spec_const(C, "int", "l_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 l_axis_map = unhash_axis_map(l_layout); -const lowp int l_packed_dim = unhash_packed_dim(l_layout); - -${layout_declare_spec_const(C, "int", "m_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 m_axis_map = unhash_axis_map(m_layout); -const lowp int m_packed_dim = unhash_packed_dim(m_layout); 
- -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// Maximum block sizes to prevent array overflow -#define MAX_BR 64 -#define MAX_BC 128 - -// Texture access helper functions using proper axis mapping -// Q_sizes, K_sizes, V_sizes, O_sizes are [D, H, N, B] (UBO layout) -// l_sizes, m_sizes are [B, H, N] (UBO layout) -T load_tensor_Q(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, Q_sizes, Q_axis_map, Q_packed_dim); - int component = tidx[Q_packed_dim] % 4; - vec4 texel = texelFetch(t_Q, pos, 0); - return T(texel[component]); -} - -T load_tensor_K(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, K_sizes, K_axis_map, K_packed_dim); - int component = tidx[K_packed_dim] % 4; - vec4 texel = texelFetch(t_K, pos, 0); - return T(texel[component]); -} - -T load_tensor_V(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, V_sizes, V_axis_map, V_packed_dim); - int component = tidx[V_packed_dim] % 4; - vec4 texel = texelFetch(t_V, pos, 0); - return T(texel[component]); -} - -T load_tensor_O(int batch, int seq_pos, int head, int dim) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim); - int component = tidx[O_packed_dim] % 4; - vec4 texel = imageLoad(t_O, pos); - return T(texel[component]); -} - -void store_tensor_O(int batch, int seq_pos, int head, int dim, T value) { - ivec4 tidx = ivec4(dim, head, seq_pos, batch); // Match [D, H, N, B] order - ivec3 pos = tidx_to_pos(tidx, O_sizes, O_axis_map, O_packed_dim); - int component = tidx[O_packed_dim] % 4; - vec4 texel = imageLoad(t_O, pos); - texel[component] = float(value); - imageStore(t_O, pos, texel); -} - -float load_tensor_l(int batch, int head, int seq_pos) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim); - int component = tidx[l_packed_dim] % 4; - vec4 texel = imageLoad(t_l, pos); - return texel[component]; -} - -void store_tensor_l(int batch, int head, int seq_pos, float value) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim); - int component = tidx[l_packed_dim] % 4; - vec4 texel = imageLoad(t_l, pos); - texel[component] = value; - imageStore(t_l, pos, texel); -} - -float load_tensor_m(int batch, int head, int seq_pos) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim); - int component = tidx[m_packed_dim] % 4; - vec4 texel = imageLoad(t_m, pos); - return texel[component]; -} - -void store_tensor_m(int batch, int head, int seq_pos, float value) { - ivec4 tidx = ivec4(seq_pos, head, batch, 0); // Match [N, H, B] order (with padding) - ivec3 pos = tidx_to_pos(tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim); - int component = tidx[m_packed_dim] % 4; - vec4 texel = imageLoad(t_m, pos); - texel[component] = value; - imageStore(t_m, pos, texel); - -} - -void main() { - // Each thread processes one row block - same as buffer version - const int thread_id = 
int(gl_GlobalInvocationID.x); - - // Tensor dimensions: Q_sizes = [D, H, N, B] - const int head_dim = Q_sizes.x; // D (head dim) - const int num_heads_val = Q_sizes.y; // H (num heads) - const int seq_len = Q_sizes.z; // N (sequence length) - const int batch_size = Q_sizes.w; // B (batch) - - // Block sizes - const int Br = block_size_r; - const int Bc = block_size_c; - - const int Tr = (seq_len + Br - 1) / Br; // Number of row blocks - const int total_row_blocks = batch_size * num_heads_val * Tr; - - if (thread_id >= total_row_blocks) { - return; - } - - // Decode thread_id to (batch, head, row_block) - const int batch = thread_id / (num_heads_val * Tr); - const int remaining = thread_id % (num_heads_val * Tr); - const int head = remaining / Tr; - const int row_block = remaining % Tr; - - // Calculate row range for this block - const int row_start = row_block * Br; - const int row_end = min(row_start + Br, seq_len); - const int actual_Br = row_end - row_start; - - // STEP 1: Initialize only this thread's row block - // Each thread initializes its own rows to avoid cross-workgroup synchronization issues - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - - // Initialize l and m textures for this row block's positions - ivec4 l_tidx = ivec4(batch, head, seq_pos, 0); - ivec3 l_pos = tidx_to_pos(l_tidx, ivec4(l_sizes, 1), l_axis_map, l_packed_dim); - vec4 l_texel = vec4(0.0); - imageStore(t_l, l_pos, l_texel); - - ivec4 m_tidx = ivec4(batch, head, seq_pos, 0); - ivec3 m_pos = tidx_to_pos(m_tidx, ivec4(m_sizes, 1), m_axis_map, m_packed_dim); - vec4 m_texel = vec4(-1e10); - imageStore(t_m, m_pos, m_texel); - - // Initialize output tensor for this row block - for (int dim = 0; dim < head_dim; dim++) { - store_tensor_O(batch, seq_pos, head, dim, T(0.0)); - } - } - - // STEP 5: Outer loop over column blocks (For K, V tensors) - const int Tc = (seq_len + Bc - 1) / Bc; // Number of column blocks - for (int j = 0; j < Tc; j++) { - const int col_start = j * Bc; - const int col_end = min(col_start + Bc, seq_len); - const int actual_Bc = col_end - col_start; - - // Load current statistics for all rows in this block - float m_i[MAX_BR]; - float l_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - m_i[r] = load_tensor_m(batch, head, seq_pos); - l_i[r] = load_tensor_l(batch, head, seq_pos); - } - - // STEP 9: Compute Sij = Qi * Kj^T - T S_block[MAX_BR][MAX_BC]; - float m_tilde_ij[MAX_BR]; // Row maxes - float l_tilde_ij[MAX_BR]; // Row sums - - // Initialize row statistics - for (int r = 0; r < actual_Br; r++) { - m_tilde_ij[r] = -1.0 / 0.0; // -infinity - l_tilde_ij[r] = 0.0; - } - - // Compute attention scores Sij = Qi @ Kj^T - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads_val; - - // Dot product: Q[seq_pos, :] · K[col_pos, :] - T score = T(0.0); - for (int dim = 0; dim < head_dim; dim++) { - T q_val = load_tensor_Q(batch, global_row, head, dim); - T k_val = load_tensor_K(batch, global_col, kv_head, dim); - score += q_val * k_val; - } - score *= scale; - - - // Apply causal masking: mask if global_col > global_row + input_pos - bool masked = (global_col > global_row + input_pos); - if (masked) { - score = T(-1.0 / 0.0); // Set to negative infinity - } - - S_block[r][c] = score; - - - // Track row maximum (after 
masking) - m_tilde_ij[r] = max(m_tilde_ij[r], float(score)); - } - } - - // STEP 10: Compute P'ij = exp(Sij − m'ij) and l'ij = rowsum(P'ij) - for (int r = 0; r < actual_Br; r++) { - // Handle the case where all scores are -inf (fully masked row) - if (isinf(m_tilde_ij[r]) && m_tilde_ij[r] < 0.0) { - // All scores are -inf, so all probabilities are 0 - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = 0.0; - } - l_tilde_ij[r] = 0.0; - } else { - // Normal case: compute softmax - for (int c = 0; c < actual_Bc; c++) { - S_block[r][c] = exp(S_block[r][c] - T(m_tilde_ij[r])); - l_tilde_ij[r] += float(S_block[r][c]); - } - } - } - - // STEP 11: Softmax update - float m_new_i[MAX_BR]; - float l_new_i[MAX_BR]; - for (int r = 0; r < actual_Br; r++) { - m_new_i[r] = max(m_i[r], m_tilde_ij[r]); - l_new_i[r] = exp(m_i[r] - m_new_i[r]) * l_i[r] + exp(m_tilde_ij[r] - m_new_i[r]) * l_tilde_ij[r]; - - } - - // STEP 12: Update Oi - for (int r = 0; r < actual_Br; r++) { - const int global_row = row_start + r; - float alpha = exp(m_i[r] - m_new_i[r]); - float beta = exp(m_tilde_ij[r] - m_new_i[r]); - - // For multi-query attention: map query head to KV head - const int kv_head = (head * num_kv_heads) / num_heads_val; - - for (int dim = 0; dim < head_dim; dim++) { - // Compute P'ij @ Vj for this dimension - T pv_sum = T(0.0); - for (int c = 0; c < actual_Bc; c++) { - const int global_col = col_start + c; - T v_val = load_tensor_V(batch, global_col, kv_head, dim); - pv_sum += S_block[r][c] * v_val; - } - - // Check for division by zero before updating output - if (l_new_i[r] <= 0.0) { - store_tensor_O(batch, global_row, head, dim, T(0.0)); - } else { - // Oi = (alpha * l_i * Oi + beta * P'ij @ Vj) / l_new_i - T current_o = load_tensor_O(batch, global_row, head, dim); - T new_o = (T(alpha) * T(l_i[r]) * current_o + T(beta) * pv_sum) / T(l_new_i[r]); - store_tensor_O(batch, global_row, head, dim, new_o); - - } - } - } - - // STEP 13: Update li, mi - for (int r = 0; r < actual_Br; r++) { - const int seq_pos = row_start + r; - store_tensor_l(batch, head, seq_pos, l_new_i[r]); - store_tensor_m(batch, head, seq_pos, m_new_i[r]); - } - - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl deleted file mode 100644 index 8028362c3e5..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.glsl +++ /dev/null @@ -1,80 +0,0 @@ -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} - -${define_active_storage_type(STORAGE)} -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "cache", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "projected", DTYPE, STORAGE)} -$if STORAGE == "buffer": - ${layout_declare_ubo(B, "int", "projected_numel")} - ${layout_declare_ubo(B, "ivec4", "cache_strides")} - ${layout_declare_ubo(B, "int", "input_pos")} -$else: - ${layout_declare_ubo(B, "ivec3", "projected_limits")} - ${layout_declare_ubo(B, "int", "input_pos")} - - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - * t_cache will have sizes of (max_batch_size, max_seq_len, n_heads, head_dim). - * t_projected will have sizes of (batch_size, seq_len, n_heads, head_dim). - * - * The cache update inserts the values of t_projected into t_cache at the index - * specified by input_pos at the seq_len dimension. 
It is equivalent to calling - - * t_cache = t_cache.slice_scatter( - * t_projected, dim=1, start=input_pos, end=input_pos+seq_len) - * - * Note that this shader is implemented assuming that max_batch_size is 1. - */ - -#ifdef USING_BUFFER - -/*************************** - ** Buffer Implementation ** - ***************************/ - -void main() { - int projected_bufi = int(gl_GlobalInvocationID.x); - // Bump cache index forward by input_pos elements along the seq_len dimension. - // cache_strides contains the strides of the cache tensor. - int cache_bufi = input_pos * cache_strides.z + projected_bufi; - if (projected_bufi >= projected_numel) { - return; - } - cache[cache_bufi] = projected[projected_bufi]; -} - -#else - -/**************************** - ** Texture Implementation ** - ****************************/ - -// Note that this shader assumes the that tensors are width packed, i.e. -// packed_dim = 0 -void main() { - const ivec3 projected_pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(projected_pos, projected_limits))) { - return; - } - - const ivec3 cache_pos = ivec3( - projected_pos.x, - projected_pos.y, - projected_pos.z + input_pos); - - write_texel(cache, cache_pos, load_texel(projected, projected_pos)); -} - -#endif // USING_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl deleted file mode 100644 index 1e854bf7f85..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.glsl +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} - -${define_active_storage_type(STORAGE)} -${define_required_extensions(DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "rw", "attn_weight", DTYPE, STORAGE)} - -$if STORAGE == "buffer": - ${layout_declare_ubo(B, "ivec4", "attn_weight_sizes")} - ${layout_declare_ubo(B, "ivec4", "attn_weight_strides")} -$else: - ${layout_declare_ubo(B, "ivec3", "attn_weight_limits")} - -${layout_declare_ubo(B, "int", "input_pos")} -${layout_declare_ubo(B, "float", "scale")} - - -#include "indexing_utils.h" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -// Negative infinity is represented by having sign bit be 1, all exponent bits -// be 1, all mantissa bits be 0. -#define NEGATIVE_INF_BITS 0xFF800000 -const float negative_infinity = NEGATIVE_INF_BITS; - -#ifdef USING_BUFFER - -/* - * This implementations applies a scale and mask to the attention weight tensor - * of an SDPA block. The sizes of the attention weight is - * (batch_size, n_heads, seq_len, input_pos + seq_len) - * Conceptually the weights represent the relationship between each token in the - * sequence with each token preceding it. - * - * The scale applied is 1.0 / sqrt(head_dim_length) - * - * The mask applied is a bit more complicated. Imagine you create a square - * matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the - * lower triangular section of the matrix to -inf. Then, slice the matrix along - * the row dimension starting from input_pos to input_pos + seq_len. 
You end up - * with a partial mask with size (seq_len, input_pos + seq_len). This is the - * mask that is applied to the attention weight. - * - * In the shader, instead of generating the mask, the index of the elment is - * inspected to determine if it would have been masked. Given an element at - * tensor index (n, c, h, w), it would be masked if w < h + input_pos. - */ - -/*************************** - ** Buffer Implementation ** - ***************************/ - -void main() { - const ivec4 attn_weight_idx = ivec4( - gl_GlobalInvocationID.x, - gl_GlobalInvocationID.y, - gl_GlobalInvocationID.z, - 0); - - if (any(greaterThanEqual(attn_weight_idx, attn_weight_sizes))) { - return; - } - - const T scale_conv = T(scale); - - const int attn_weight_id = tidx_to_bufi(attn_weight_idx, attn_weight_strides); - if (attn_weight_idx.x <= attn_weight_idx.y + input_pos) { - attn_weight[attn_weight_id] = attn_weight[attn_weight_id] * scale_conv; - } else { - attn_weight[attn_weight_id] = T(negative_infinity); - } -} - -#else - -/**************************** - ** Texture Implementation ** - ****************************/ - -/* - * This implementation assumes that the attention weight is width packed, i.e. - * the packed dim of the attn_weight is 0. - */ -void main() { - const ivec3 attn_weight_pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(attn_weight_pos, attn_weight_limits))) { - return; - } - - vec4 outtex = imageLoad(attn_weight, attn_weight_pos) * scale; - - // Mask out the upper triangular of attn_weight to -inf - [[unroll]] for (int i = 0; i < 4; ++i) { - if (attn_weight_pos.x * 4 + i > attn_weight_pos.y + input_pos) { - outtex[i] = negative_infinity; - } - } - - write_texel(attn_weight, attn_weight_pos, outtex); -} - -#endif // USING_BUFFER diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml deleted file mode 100644 index ca8806fe000..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weight_scale_and_mask.yaml +++ /dev/null @@ -1,13 +0,0 @@ -sdpa_attn_weight_scale_and_mask: - parameter_names_with_default_values: - DTYPE: float - STORAGE: buffer - generate_variant_forall: - STORAGE: - - VALUE: buffer - - VALUE: texture3d - DTYPE: - - VALUE: half - - VALUE: float - shader_variants: - - NAME: sdpa_attn_weight_scale_and_mask diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl new file mode 100644 index 00000000000..1dff0017f30 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.glsl @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +#define NUM_WORKERS_PER_WG 64 + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_attn_weights_softmax", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Shared memory for the cooperative exp sum reduction +shared T shared_exp_sum[NUM_WORKERS_PER_WG]; + +VEC4_T load_attn_weights_c4( + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef USING_BUFFER + return t_attn_weights[(q_h * S * C4) + (s * C4) + c4]; +#else + return texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0); +#endif +} + +void store_attn_weights_softmax_c4( + const VEC4_T out_texel, + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef USING_BUFFER + t_attn_weights_softmax[(q_h * S * C4) + (s * C4) + c4] = out_texel; +#else + imageStore(t_attn_weights_softmax, ivec3(c4, s, q_h), out_texel); +#endif +} + +void main() { + const int worker_id = int(gl_LocalInvocationID.x); + + // Index along attention weight's sequence_len dim + const int s = int(gl_GlobalInvocationID.y); + // idx along attention weight's num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + // manually determine size of the context_len dim of the attention weight. + // The "actual" tensor sizes may have been aligned to a multiple of 4 to allow + // memory loads to be aligned to texel boundaries. + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + if (s >= S || q_h >= Q_H) { + return; + } + + // Initialize thread-local exp sum + T local_exp_sum = 0; + + const int context_len_aligned_down = context_len - mod_4(context_len); + const int C4_limit = div_4(context_len_aligned_down); + + // Each thread processes elements along a context_len row with a stride of the + // number of threads in the work group. 
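+ // In outline: this first pass only accumulates each worker's partial sum of + // exp(attn_weight) over its strided slice of the row; the partial sums are + // combined by the shared-memory tree reduction below, and a second pass then + // divides each element's exp() by the reduced row sum. 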
+ for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + for (int comp = 0; comp < 4; comp++) { + local_exp_sum += exp(in_texel[comp]); + } + } + // First thread in the work group responsible for handling last texel if it + // contains any padded elements + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { + const int c_base = mul_4(c4); + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + [[unroll]] for (int comp = 0; comp < 4; comp++) { + if (c_base + comp < context_len) { + local_exp_sum += exp(in_texel[comp]); + } + } + } + } + + // Store thread-local results in shared memory + shared_exp_sum[worker_id] = local_exp_sum; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result + for (int i = NUM_WORKERS_PER_WG / 2; i > 0; i >>= 1) { + if (worker_id < i) { + shared_exp_sum[worker_id] = shared_exp_sum[worker_id] + + shared_exp_sum[worker_id + i]; + } + memoryBarrierShared(); + barrier(); + } + + local_exp_sum = shared_exp_sum[0]; + // Now go back through each element in the row and normalize + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_WG) { + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + VEC4_T out_texel = exp(in_texel) / local_exp_sum; + store_attn_weights_softmax_c4( + out_texel, c4, s, q_h, context_texel_len, S, Q_H); + } + // First thread in the work group responsible for handling last texel if it + // contains any padded elements + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; ++c4) { + const int c_base = mul_4(c4); + VEC4_T in_texel = load_attn_weights_c4( + c4, s, q_h, context_texel_len, S, Q_H); + + // Ensure that padding elements are set to 0. + VEC4_T out_texel = VEC4_T(0); + [[unroll]] for (int comp = 0; comp < 4; comp++) { + if (c_base + comp < context_len) { + out_texel[comp] = exp(in_texel[comp]) / local_exp_sum; + } + } + store_attn_weights_softmax_c4( + out_texel, c4, s, q_h, context_texel_len, S, Q_H); + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml similarity index 82% rename from backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.yaml rename to backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml index e2a96234465..8abf50399e0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/kv_cache_update.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_attn_weights_softmax.yaml @@ -4,16 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -kv_cache_update: +sdpa_attn_weights_softmax: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer + STORAGE: texture3d generate_variant_forall: STORAGE: - - VALUE: buffer - VALUE: texture3d + - VALUE: buffer DTYPE: - - VALUE: half - VALUE: float shader_variants: - - NAME: kv_cache_update + - NAME: sdpa_attn_weights_softmax diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl new file mode 100644 index 00000000000..4b7e3e0ddd2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.glsl @@ -0,0 +1,213 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
 * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if K_CACHE_STORAGE == "buffer": + #define K_CACHE_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define NUM_WORKERS_PER_OUT 64 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "float", "inv_scale", "1.0")} + +#include "sdpa_fp_q_projected_tile_load.glslh" +#include "sdpa_fp_k_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_attn_weight_tile_store.glslh" + +shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT]; + +/* + * See the tiled variant of this shader for the implemented behavior. This + * shader implements an optimization for cases where sequence length is 1; in + * these cases, the matrix multiplication being performed is akin to gemv, which + * benefits from using a co-operative algorithm for reduction. For this shader, + * the entire work group co-operates to compute one reduction output. + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int worker_id = int(gl_LocalInvocationID.y); + + const int tile_idx_x = int(gl_GlobalInvocationID.x); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output context_len dim + const int c = tile_idx_x * TILE_N; + const int c4 = div_4(c); + + // idx along the output seq_len dim. Note that for this shader seq_len will be + // 1. + const int s = 0; + + // texel size of head_dim, over which the dot product is accumulated + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = k_cache_sizes.y; + // Max context length + const int C = k_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + // current context length + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (c >= context_len || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile q_tile; + FPWeightTile w_tile; + + // If the tile is completely inside the mask region, then there is no need to + // compute the output tile. All the elements in the output tile can be set to + // negative infinity. 
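+ // A tile lies entirely in the mask region when its first column index c is + // greater than the largest unmasked column of the last row it covers, which is + // input_pos + s + (TILE_M - 1) under the causal rule col > row + input_pos. 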
+ bool tile_in_mask_region = c > (input_pos + s + (TILE_M - 1)); + if (tile_in_mask_region) { + const VEC4_T negative_infinity_vec = VEC4_T(negative_infinity_val); + set_out_tile_to_vec(out_tile, negative_infinity_vec); + } + // Otherwise, need to actually compute output tile + else { + const bool dont_check_bounds = (S - s) >= TILE_M && + (context_len - c) >= TILE_N; + + if (dont_check_bounds) { + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_no_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } else { + for (int d4 = worker_id; d4 < D4; d4 += NUM_WORKERS_PER_OUT) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } + } + + partial_sums[worker_id] = out_tile; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result. + for (int i = NUM_WORKERS_PER_OUT / 2; i > 0; i /= 2) { + if (worker_id < i) { + accumulate_out_tile_with_out_tile( + partial_sums[worker_id], partial_sums[worker_id + i]); + } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread will write out the result + if (worker_id == 0) { + out_tile = partial_sums[0]; + // Apply scale and mask if the tile was not entirely in the mask region + if (!tile_in_mask_region) { + VEC4_T inv_scale_vec = VEC4_T(inv_scale); + apply_scale_and_mask( + out_tile, + inv_scale_vec, + input_pos, + c, + s); + } + + store_attn_weight_tile_with_checks( + out_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml new file mode 100644 index 00000000000..6a4cffcc913 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +sdpa_compute_attn_weights_coop: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + K_CACHE_STORAGE: texture3d + TILE_K4: 1 + TILE_N4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d + - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl new file mode 100644 index 00000000000..577d7dea749 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.glsl @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if K_CACHE_STORAGE == "buffer": + #define K_CACHE_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_q_projected", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_k_cache", DTYPE, K_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "k_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "float", "inv_scale", "1.0")} + +#include "sdpa_fp_q_projected_tile_load.glslh" +#include "sdpa_fp_k_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_attn_weight_tile_store.glslh" + +/* + * Compute attention weights given the q_projected and k_cache tensors. + * q_projected has shape (batches, seq_len, num_q_heads, head_dim) + * k_cache has shape (batches, max_context_len, num_kv_heads, head_dim) + * output has shape (batches, num_q_heads, seq_len, context_len) + * + * This shader also applies scaling and masking to the computed attention + * weights. + * + * The scale applied is 1.0 / sqrt(head_dim_length). + * + * The mask applied is a bit more complicated. Imagine you create a square + * matrix of size (input_pos + seq_len, input_pos + seq_len), and then set the + * upper triangular section of the matrix to -inf. Then, slice the matrix along + * the row dimension starting from input_pos to input_pos + seq_len. You end up + * with a partial mask with size (seq_len, input_pos + seq_len). This is the + * mask that is applied to the attention weight. + * + * In the shader, instead of generating the mask, the index of the element is + * inspected to determine if it would have been masked. Given an element at + * tensor index (n, c, h, w), it would be masked if w > h + input_pos. 
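+ * + * As an illustrative example (with made-up values): for input_pos = 2 and + * seq_len = 2, the sliced mask has shape (2, 4); the row at h = 0 corresponds + * to query position 2, so columns w = 0..2 are kept and w = 3 is masked, while + * the row at h = 1 (query position 3) keeps all four columns. 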
+ * + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int tile_idx_x = int(gl_GlobalInvocationID.x); + const int tile_idx_y = int(gl_GlobalInvocationID.y); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output context_len dim + const int c = tile_idx_x * TILE_N; + const int c4 = div_4(c); + + // idx along the output seq_len dim + const int s = tile_idx_y * TILE_M; + const int s4 = div_4(s); + + // texel size of head_dim, over which the dot product is accumulated + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = k_cache_sizes.y; + // Max context length + const int C = k_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (c >= context_len || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile q_tile; + FPWeightTile w_tile; + + // If the tile is completely inside the mask region, then there is no need to + // compute the output tile. All the elements in the output tile can be set to + // negative infinity. + bool tile_in_mask_region = c > (input_pos + s + (TILE_M - 1)); + if (tile_in_mask_region) { + const VEC4_T negative_infinity_vec = VEC4_T(negative_infinity_val); + set_out_tile_to_vec(out_tile, negative_infinity_vec); + } + // Otherwise, need to actually compute output tile + else { + const bool dont_check_bounds = (S - s) >= TILE_M && + (context_len - c) >= TILE_N; + + if (dont_check_bounds) { + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_no_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } else { + for (int d4 = 0; d4 < D4; d4++) { + load_q_projected_tile_with_checks( + q_tile, + d4, + s, + q_h, + D4, + Q_H, + S); + + load_k_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, q_tile, w_tile); + } + } + + // Apply scale and mask + VEC4_T inv_scale_vec = VEC4_T(inv_scale); + apply_scale_and_mask( + out_tile, + inv_scale_vec, + input_pos, + c, + s); + } + + store_attn_weight_tile_with_checks( + out_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml new file mode 100644 index 00000000000..6aadbbc379e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +sdpa_compute_attn_weights_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + K_CACHE_STORAGE: texture3d + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d + - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl new file mode 100644 index 00000000000..1fdd803d02b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.glsl @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if V_CACHE_STORAGE == "buffer": + #define V_CACHE_BUFFER + +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M 1 +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +#define NUM_WORKERS_PER_OUT 64 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "float", "inv_scale", "1.0")} + +#include "sdpa_fp_attn_weight_tile_load.glslh" +#include "sdpa_fp_v_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_out_tile_store.glslh" + +shared FPOutTile partial_sums[NUM_WORKERS_PER_OUT]; + +/* + * See the tiled variant of this shader for the implemented behavior. This + * shader implements an optimization for cases where sequence length is 1; in + * these cases, the matrix multiplication being performed is akin to gemv, which + * benefits from using a co-operative algorithm for reduction. For this shader, + * the entire work group co-operates to compute one reduction output. + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int worker_id = int(gl_LocalInvocationID.y); + + const int tile_idx_x = int(gl_GlobalInvocationID.x); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output head_dim dim + const int d = tile_idx_x * TILE_N; + const int d4 = div_4(d); + + // idx along the output seq_len dim. Note that for this shader seq_len will be + // 1. 
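+ // With a single query row, attn_weights @ v_cache reduces to a weighted sum of + // value cache rows (effectively a gemv), so each worker below accumulates a + // strided partial tile that is then combined via the tree reduction. 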
+ const int s = 0; + + // texel size of head_dim + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = v_cache_sizes.y; + // Max context length + const int C = v_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + // current context length + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (d4 >= D4 || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile attn_weight_tile; + FPWeightTile w_tile; + + const int context_len_aligned_down = context_len - mod_4(context_len); + const int C4_limit = div_up_4(context_len_aligned_down); + + for (int c4 = worker_id; c4 < C4_limit; c4 += NUM_WORKERS_PER_OUT) { + const int c = mul_4(c4); + + load_attn_weight_tile_no_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + // first worker in the work group will handle final texel, which may contain + // padding elements. + if (worker_id == 0) { + for (int c4 = C4_limit; c4 < context_texel_len; c4++) { + const int c = mul_4(c4); + load_attn_weight_tile_with_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + } + + partial_sums[worker_id] = out_tile; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result. + for (int i = NUM_WORKERS_PER_OUT / 2; i > 0; i /= 2) { + if (worker_id < i) { + accumulate_out_tile_with_out_tile( + partial_sums[worker_id], partial_sums[worker_id + i]); + } + memoryBarrierShared(); + barrier(); + } + + // Only the first thread will write out the result + if (worker_id == 0) { + out_tile = partial_sums[0]; + store_sdpa_out_tile_with_checks( + out_tile, + d4, + s, + q_h, + D4, + S, + Q_H); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml similarity index 56% rename from backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml rename to backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml index 909b8bfd3a9..ccebf8f7c1c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_texture3d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml @@ -4,12 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
 -flash_attention_texture3d: +sdpa_compute_out_coop: parameter_names_with_default_values: DTYPE: float - STORAGE: texture3d + IO_STORAGE: texture3d + V_CACHE_STORAGE: texture3d + TILE_K4: 1 + TILE_N4: 1 generate_variant_forall: DTYPE: - VALUE: float + - VALUE: half shader_variants: - - NAME: flash_attention_texture3d + - NAME: sdpa_compute_out_coop_texture3d_texture3d + - NAME: sdpa_compute_out_coop_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl new file mode 100644 index 00000000000..fb4eaded826 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.glsl @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if V_CACHE_STORAGE == "buffer": + #define V_CACHE_BUFFER + +#define TILE_M4 ${TILE_M4} +// Equivalent to K4 in matrix multiplication +#define TILE_K4 ${TILE_K4} +// Equivalent to N4 in matrix multiplication +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_attn_weights", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_v_cache", DTYPE, V_CACHE_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "q_projected_sizes")} +${layout_declare_ubo(B, "ivec4", "v_cache_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "sdpa_fp_attn_weight_tile_load.glslh" +#include "sdpa_fp_v_cache_tile_load.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "sdpa_fp_out_tile_store.glslh" + +/* + * Compute SDPA output given the attention weights and v_cache tensors. 
+ * attention weights has shape (batches, num_q_heads, seq_len, context_len) + * v_cache has shape (batches, max_context_len, num_kv_heads, head_dim) + * output has shape (batches, seq_len, num_q_heads, head_dim) + */ + +#extension GL_EXT_debug_printf : enable + +void main() { + const int tile_idx_x = int(gl_GlobalInvocationID.x); + const int tile_idx_y = int(gl_GlobalInvocationID.y); + // idx along output num_q_heads dim + const int q_h = int(gl_GlobalInvocationID.z); + + // idx along the output head_dim dim + const int d = tile_idx_x * TILE_N; + const int d4 = div_4(d); + + // idx along the output seq_len dim + const int s = tile_idx_y * TILE_M; + + // texel size of head_dim + const int D4 = div_up_4(q_projected_sizes.x); + // number of Q heads + const int Q_H = q_projected_sizes.y; + // sequence length + const int S = q_projected_sizes.z; + + // number of K/V heads + const int KV_H = v_cache_sizes.y; + // Max context length + const int C = v_cache_sizes.z; + const int C4 = div_up_4(C); + + int kv_h = q_h; + if (KV_H < Q_H) { + kv_h = q_h / (Q_H / KV_H); + } + + // current context length + const int context_len = input_pos + S; + const int context_texel_len = div_up_4(context_len); + + // bounds check + if (d4 >= D4 || s >= S || q_h >= Q_H) { + return; + } + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile attn_weight_tile; + FPWeightTile w_tile; + + const int context_len_aligned_down = context_len - mod_4(context_len); + const int C4_limit = div_4(context_len_aligned_down); + + for (int c4 = 0; c4 < C4_limit; c4++) { + const int c = mul_4(c4); + load_attn_weight_tile_no_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_no_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + for (int c4 = C4_limit; c4 < context_texel_len; c4++) { + const int c = mul_4(c4); + load_attn_weight_tile_with_checks( + attn_weight_tile, + c4, + s, + q_h, + context_texel_len, + S, + Q_H); + + load_v_cache_tile_with_checks( + w_tile, + d4, + c, + kv_h, + D4, + context_len, + C, + KV_H); + + fp_accumulate_with_fp_weight(out_tile, attn_weight_tile, w_tile); + } + + store_sdpa_out_tile_with_checks( + out_tile, + d4, + s, + q_h, + D4, + S, + Q_H); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml new file mode 100644 index 00000000000..7fbce29e908 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +sdpa_compute_out_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + V_CACHE_STORAGE: texture3d + TILE_M4: 1 + TILE_K4: 1 + TILE_N4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: sdpa_compute_out_tiled_texture3d_texture3d + - NAME: sdpa_compute_out_tiled_buffer_texture3d + IO_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh new file mode 100644 index 00000000000..12b2292fa45 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_attn_weights + * + * Macro Settings: + * - INPUT_BUFFER + */ + +#ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#define SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +VEC4_T load_attn_weight_c4( + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef INPUT_BUFFER + return t_attn_weights[(q_h * S * C4) + (s * C4) + c4]; +#else + return texelFetch(t_attn_weights, ivec3(c4, s, q_h), 0); +#endif +} + +void load_attn_weight_tile_no_checks( + out FPInputTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + tile.data[s][c4] = + load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } + } +} + +void load_attn_weight_tile_with_checks( + out FPInputTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + if (c4_start + c4 < C4 && s_start + s < S) { + tile.data[s][c4] = + load_attn_weight_c4(c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } else { + tile.data[s][c4] = VEC4_T(0.0); + } + } + } +} + +#endif // SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh new file mode 100644 index 00000000000..c64d9af8cfb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_attn_weight_tile_store.glslh @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_attn_weights + * + * Macro Settings: + * - OUTPUT_BUFFER + */ + +#ifndef SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH +#define SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +T negative_infinity_val = T(-1.0 / 0.0); + +void store_attn_weight_c4( + const VEC4_T out_texel, + const int c4, + const int s, + const int q_h, + const int C4, + const int S, + const int Q_H) { +#ifdef OUTPUT_BUFFER + t_attn_weights[(q_h * S * C4) + (s * C4) + c4] = out_texel; +#else + imageStore(t_attn_weights, ivec3(c4, s, q_h), out_texel); +#endif +} + +void store_attn_weight_tile_no_checks( + const FPOutTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + store_attn_weight_c4( + tile.data[s][c4], c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } + } +} + +void store_attn_weight_tile_with_checks( + const FPOutTile tile, + const int c4_start, + const int s_start, + const int q_h, + const int C4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + if (c4_start + c4 < C4 && s_start + s < S) { + store_attn_weight_c4( + tile.data[s][c4], c4_start + c4, s_start + s, q_h, C4, S, Q_H); + } + } + } +} + +void set_out_tile_to_vec(out FPOutTile tile, const VEC4_T vec) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { tile.data[s][c4] = vec; } + } +} + +void apply_scale_and_mask( + inout FPOutTile tile, + const VEC4_T inv_scale_vec, + const int input_pos, + const int c_idx_start, + const int s_idx_start) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int c4 = 0; c4 < TILE_N4; ++c4) { + tile.data[s][c4] = tile.data[s][c4] * inv_scale_vec; + + const int c_base = mul_4(c4); + [[unroll]] for (int c4i = 0; c4i < 4; ++c4i) { + const int c = c_base + c4i; + // Indices of the tile element in the overall output tensor + const int c_idx = c_idx_start + c; + const int s_idx = s_idx_start + s; + if (c_idx > s_idx + input_pos) { + tile.data[s][c4][c4i] = negative_infinity_val; + } + } + } + } +} + +#endif // SDPA_FP_ATTN_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh new file mode 100644 index 00000000000..03132db1348 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_k_cache_tile_load.glslh @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_k_cache + * + * Macro Settings: + * - K_CACHE_BUFFER + */ + +#ifndef SDPA_FP_K_CACHE_TILE_LOAD_GLSLH +#define SDPA_FP_K_CACHE_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_weight_tile.glslh" + +VEC4_T load_k_cache_d4( + const int d4, + const int c, + const int kv_h, + const int D4, + const int C, + const int KV_H) { +#ifdef K_CACHE_BUFFER + return VEC4_T(t_k_cache[(c * KV_H * D4) + (kv_h * D4) + d4]); +#else + return VEC4_T(texelFetch(t_k_cache, ivec3(d4, kv_h, c), 0)); +#endif +} + +void load_k_cache_tile_no_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + bool should_print = d4_start == 0 && c_start == 0 && kv_h == 0; + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + const int c4 = div_4(c); + const int c4i = mod_4(c); + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + VEC4_T d4_row = + load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + + // Transpose in-place + const int d_base = mul_4(d4); + tile.data[d_base][c4][c4i] = d4_row[0]; + tile.data[d_base + 1][c4][c4i] = d4_row[1]; + tile.data[d_base + 2][c4][c4i] = d4_row[2]; + tile.data[d_base + 3][c4][c4i] = d4_row[3]; + } + } +} + +void load_k_cache_tile_with_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + const int c4 = div_4(c); + const int c4i = mod_4(c); + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + VEC4_T d4_row = VEC4_T(0.0); + if (d4_start + d4 < D4 && c_start + c < context_len) { + d4_row = load_k_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + } + + // Transpose in-place + const int d_base = mul_4(d4); + tile.data[d_base][c4][c4i] = d4_row[0]; + tile.data[d_base + 1][c4][c4i] = d4_row[1]; + tile.data[d_base + 2][c4][c4i] = d4_row[2]; + tile.data[d_base + 3][c4][c4i] = d4_row[3]; + } + } +} + +#endif // SDPA_FP_K_CACHE_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh new file mode 100644 index 00000000000..17e0988a6a4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_out_tile_store.glslh @@ -0,0 +1,57 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_attn_weights + * + * Macro Settings: + * - OUTPUT_BUFFER + */ + +#ifndef SDPA_FP_OUT_TILE_LOAD_GLSLH +#define SDPA_FP_OUT_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_output_tile.glslh" + +void store_out_d4( + const VEC4_T out_texel, + const int d4, + const int q_h, + const int s, + const int D4, + const int Q_H, + const int S) { +#ifdef OUTPUT_BUFFER + t_output[(s * Q_H * D4) + (q_h * D4) + d4] = out_texel; +#else + imageStore(t_output, ivec3(d4, q_h, s), out_texel); +#endif +} + +void store_sdpa_out_tile_with_checks( + const FPOutTile tile, + const int d4_start, + const int s_start, + const int q_h, + const int D4, + const int S, + const int Q_H) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int d4 = 0; d4 < TILE_N4; ++d4) { + if (d4_start + d4 < D4 && s_start + s < S) { + store_out_d4( + tile.data[s][d4], d4_start + d4, q_h, s_start + s, D4, Q_H, S); + } + } + } +} + +#endif // SDPA_FP_OUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh new file mode 100644 index 00000000000..a304e5019e9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_q_projected_tile_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_input + * + * Macro Settings: + * - INPUT_BUFFER + */ + +#ifndef SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH +#define SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_input_tile.glslh" + +VEC4_T load_q_projected_d4( + const int d4, + const int q_h, + const int s, + const int D4, + const int Q_H, + const int S) { +#ifdef INPUT_BUFFER + return t_q_projected[(s * Q_H * D4) + (q_h * D4) + d4]; +#else + return texelFetch(t_q_projected, ivec3(d4, q_h, s), 0); +#endif +} + +void load_q_projected_tile_no_checks( + out FPInputTile tile, + const int d4_start, + const int s_start, + const int q_h, + const int D4, + const int Q_H, + const int S) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + tile.data[s][d4] = + load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + } + } +} + +void load_q_projected_tile_with_checks( + out FPInputTile tile, + const int d4_start, + const int s_start, + const int q_h, + const int D4, + const int Q_H, + const int S) { + [[unroll]] for (int s = 0; s < TILE_M; ++s) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + if (d4_start + d4 < D4 && s_start + s < S) { + tile.data[s][d4] = + load_q_projected_d4(d4_start + d4, q_h, s_start + s, D4, Q_H, S); + } else { + tile.data[s][d4] = VEC4_T(0.0); + } + } + } +} + +#endif // SDPA_FP_Q_PROJECTED_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh new file mode 100644 index 00000000000..bf94b251c43 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_fp_v_cache_tile_load.glslh @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Assume the following variables are defined in the shader layout: + * - t_v_cache + * + * Macro Settings: + * - V_CACHE_BUFFER + */ + +#ifndef SDPA_FP_V_CACHE_TILE_LOAD_GLSLH +#define SDPA_FP_V_CACHE_TILE_LOAD_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +#include "linear_fp_weight_tile.glslh" + +VEC4_T load_v_cache_d4( + const int d4, + const int c, + const int kv_h, + const int D4, + const int C, + const int KV_H) { +#ifdef V_CACHE_BUFFER + return VEC4_T(t_v_cache[(c * KV_H * D4) + (kv_h * D4) + d4]); +#else + return VEC4_T(texelFetch(t_v_cache, ivec3(d4, kv_h, c), 0)); +#endif +} + +void load_v_cache_tile_no_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + tile.data[c][d4] = + load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + } + } +} + +void load_v_cache_tile_with_checks( + out FPWeightTile tile, + const int d4_start, + const int c_start, + const int kv_h, + const int D4, + const int context_len, + const int C, + const int KV_H) { + [[unroll]] for (int c = 0; c < TILE_N; ++c) { + [[unroll]] for (int d4 = 0; d4 < TILE_K4; ++d4) { + if (d4_start + d4 < D4 && c_start + c < context_len) { + tile.data[c][d4] = + load_v_cache_d4(d4_start + d4, c_start + c, kv_h, D4, C, KV_H); + } else { + tile.data[c][d4] = VEC4_T(0.0); + } + } + } +} + +#endif // SDPA_FP_V_CACHE_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl new file mode 100644 index 00000000000..932696fff02 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -0,0 +1,90 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${buffer_scalar_type(DTYPE)} + +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_cache", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_projected", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "cache_sizes")} +${layout_declare_ubo(B, "ivec4", "projected_sizes")} +${layout_declare_ubo(B, "int", "input_pos")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * t_cache will have sizes of (batches, n_heads, max_context_len, head_dim). + * t_projected will have sizes of (batches, seq_len, n_heads, head_dim). + * + * Note that the cache tensor swaps the order of the n_heads and seq len + * dimensions. This is to faciliate more optimal memory access patterns when + * using the caches to compute matrix multiplications. + * + * The cache update inserts the values of t_projected into t_cache at the index + * specified by input_pos at the seq_len dimension. It is equivalent to calling + + * t_cache = t_cache.slice_scatter( + * t_projected, dim=1, start=input_pos, end=input_pos+seq_len) + * + * Note that this shader is implemented assuming that max_batch_size is 1. 
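+ *
+ * For example, with input_pos = 7 and seq_len = 3, the projected values for
+ * the three new tokens are written into positions 7, 8 and 9 along the
+ * max_context_len dimension of t_cache, and all other cache entries are left
+ * untouched.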
+ */ + +IN_VEC4_T read_projected_d4( + const int d4, + const int h, + const int s, + const int D4, + const int H, + const int S) { +#ifdef INPUT_BUFFER + return t_projected[(s * H * D4) + (h * D4) + d4]; +#else + return texelFetch(t_projected, ivec3(d4, h, s), 0); +#endif +} + +void write_cache_d4( + const IN_VEC4_T texel, + const int d4, + const int c, + const int h, + const int D4, + const int C, + const int H) { +#ifdef OUTPUT_BUFFER + t_cache[(c * H * D4) + (h * D4) + d4] = texel; +#else + imageStore(t_cache, ivec3(d4, h, c), texel); +#endif +} + +void main() { + const int d4 = int(gl_GlobalInvocationID.x); // idx along the head_dim dim + const int s = int(gl_GlobalInvocationID.y); // idx along the seq_len dim + const int h = int(gl_GlobalInvocationID.z); // idx along the n_heads dim + + const int D4 = div_up_4(projected_sizes.x); + const int S = projected_sizes.z; + const int H = projected_sizes.y; + + if (d4 >= D4 || s >= S || h >= H) { + return; + } + + const int c = s + input_pos; // idx along max_context_len dim + const int C = cache_sizes.y; + + IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S); + write_cache_d4(in_texel, d4, c, h, D4, C, H); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml similarity index 61% rename from backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml rename to backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml index 795ab906caa..85f4ce090f8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flash_attention_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml @@ -4,12 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -flash_attention_buffer: +sdpa_kv_cache_update: parameter_names_with_default_values: DTYPE: float - STORAGE: buffer + INPUT_STORAGE: texture3d + OUTPUT_STORAGE: texture3d generate_variant_forall: DTYPE: + - VALUE: half - VALUE: float shader_variants: - - NAME: flash_attention_buffer + - NAME: sdpa_kv_cache_update_texture3d + - NAME: sdpa_kv_cache_update_buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh new file mode 100644 index 00000000000..da5dcd63b31 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_q_projected_input_tile.glslh @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef SDPA_FP_Q_PROJECTED_TILE_GLSLH +#define SDPA_FP_Q_PROJECTED_TILE_GLSLH + +/* + * Macro Settings: + * - TILE_S + * - TILE_D4 + */ + +#extension GL_EXT_control_flow_attributes : require + +struct FPQProjectedTile { + VEC4_T data[TILE_S][TILE_D4]; +}; + +#ifdef DEBUG_MODE + +void printFPQProjectedTile(const FPQProjectedTile in_tile) { + debugPrintfEXT("input_tile: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + debugPrintfEXT( + " %f, %f, %f, %f, \\n", + in_tile.data[m][k4].x, + in_tile.data[m][k4].y, + in_tile.data[m][k4].z, + in_tile.data[m][k4].w); + } + } +} + +#endif // DEBUG_MODE + +#endif // SDPA_FP_Q_PROJECTED_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 2cc7455cd4a..8edaebd11ff 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -24,6 +24,58 @@ namespace vkcompute { +bool is_single_token(ComputeGraph* graph, const ValueRef& q_projected) { + return graph->size_at(-3, q_projected) == 1; +} + +// +// Resize functions +// + +void resize_compute_attn_weights_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef attn_weights = args.at(0).refs.at(0); + const ValueRef q_projected = args.at(1).refs.at(0); + const ValueRef input_pos_symint = resize_args.at(0); + + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); + + const int32_t input_pos_val = graph->read_symint(input_pos_symint); + + const uint32_t context_len = seq_len + input_pos_val; + + std::vector out_sizes = { + 1, // batch + num_q_heads, + seq_len, + utils::align_up_4(context_len)}; + + graph->virtual_resize(attn_weights, out_sizes); +} + +void resize_sdpa_attn_weights_softmax_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef attn_weights_softmax = args.at(0).refs.at(0); + const ValueRef attn_weights = args.at(1).refs.at(0); + + graph->virtual_resize(attn_weights_softmax, graph->sizes_of(attn_weights)); +} + +void resize_sdpa_compute_out_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef q_projected = resize_args.at(0); + + graph->virtual_resize(out, graph->sizes_of(q_projected)); +} + void resize_sdpa_out( ComputeGraph* graph, const std::vector& args, @@ -36,195 +88,207 @@ void resize_sdpa_out( graph->virtual_resize(out, graph->sizes_of(q_projected)); } -void resize_flash_attention_out( +// +// Shader dispatch pick functions +// + +utils::uvec3 kv_cache_update_global_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { + (void)shader; (void)resize_args; - // Find the output tensor in the args - it's the first tensor in the first - // ArgGroup - const ValueRef out = args.at(0).refs.at(0); - const ValueRef q_projected = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(q_projected)); + const ValueRef projected = args.at(1).refs.at(0); + + const uint32_t head_dim_size = graph->size_at(-1, projected); + const uint32_t num_heads = graph->size_at(-2, projected); + const uint32_t seq_len = graph->size_at(-3, projected); + + return {utils::div_up_4(head_dim_size), seq_len, num_heads}; } -utils::uvec3 flash_attention_global_wg_size( +utils::uvec3 
attn_weight_scale_and_mask_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { (void)shader; + (void)resize_args; - const ValueRef q_projected = resize_args.at(0); - const ValueRef block_size_r = resize_args.at(1); + const ValueRef attn_weight = args.at(0).refs.at(0); - // Get tensor dimensions - PyTorch format is [B, N, H, D] - // But Vulkan uses negative indexing: -4=B, -3=N, -2=H, -1=D - const int32_t B = graph->size_at(-4, q_projected); // batch - const int32_t N = graph->size_at(-3, q_projected); // sequence length - const int32_t H = graph->size_at(-2, q_projected); // num heads - const int32_t Br = - static_cast(graph->extract_scalar(block_size_r)); + if (graph->is_buffer_storage(attn_weight)) { + return { + graph->size_at(-1, attn_weight), + graph->size_at(-2, attn_weight), + graph->size_at(-3, attn_weight), + }; + } else { + return graph->logical_limits_of(attn_weight); + } +} + +vkapi::ShaderInfo pick_sdpa_compute_attn_weights_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef q_projected = args.at(1).refs.at(0); + const ValueRef k_cache = args.at(1).refs.at(1); - // Calculate number of row blocks - const int32_t Tr = (N + Br - 1) / Br; + const bool is_gemv = is_single_token(graph, q_projected); - return {static_cast(B * H * Tr), 1, 1}; + std::string shader_name = "sdpa_compute_attn_weights"; + if (is_gemv) { + shader_name += "_coop"; + } else { + shader_name += "_tiled"; + } + + add_storage_type_suffix(shader_name, graph->storage_type_of(q_projected)); + add_storage_type_suffix(shader_name, graph->storage_type_of(k_cache)); + add_dtype_suffix(shader_name, graph->dtype_of(q_projected)); + + return VK_KERNEL_FROM_STR(shader_name); } -void flash_attention_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef q_projected = args[arg_idx++]; - const ValueRef k_cache = args[arg_idx++]; - const ValueRef v_cache = args[arg_idx++]; - const ValueRef input_pos_symint = args[arg_idx++]; - const ValueRef attn_mask = args[arg_idx++]; - const ValueRef dropout_p = args[arg_idx++]; - const ValueRef is_causal = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; +utils::uvec3 pick_sdpa_compute_attn_weights_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef q_projected = args.at(1).refs.at(0); + const ValueRef input_pos_symint = resize_args.at(0); - const ValueRef out = args[arg_idx++]; + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); - // Extract input_pos value for causal masking - const int32_t input_pos_val = graph.read_symint(input_pos_symint); + const int32_t input_pos_val = graph->read_symint(input_pos_symint); - const ValueRef k_cache_tensor = k_cache; - const ValueRef v_cache_tensor = v_cache; + const uint32_t context_len = seq_len + input_pos_val; - // Validation checks - re-enable with correct indexing - VK_CHECK_COND(graph.size_at(-4, q_projected) == 1); // batch size = 1 - VK_CHECK_COND(graph.size_at(-4, k_cache_tensor) == 1); - VK_CHECK_COND(graph.size_at(-4, v_cache_tensor) == 1); - VK_CHECK_COND( - graph.sizes_of(k_cache_tensor) == graph.sizes_of(v_cache_tensor)); - VK_CHECK_COND( - graph.size_at(-1, q_projected) == - graph.size_at(-1, k_cache_tensor)); // head_dim must match - VK_CHECK_COND( - graph.val_is_none(dropout_p) || - 
graph.extract_scalar(dropout_p) == 0); - VK_CHECK_COND(graph.val_is_none(scale)); - VK_CHECK_COND( - graph.val_is_none(is_causal) || graph.extract_scalar(is_causal)); - VK_CHECK_COND(graph.val_is_none(attn_mask)); + const uint32_t N4 = utils::div_up_4(context_len); + const uint32_t M4 = utils::div_up_4(seq_len); + + return {N4, M4, num_q_heads}; +} + +utils::uvec3 pick_sdpa_compute_attn_weights_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; - if (graph.is_buffer_storage(q_projected)) { - VK_CHECK_COND(graph.is_buffer_storage(k_cache_tensor)); - VK_CHECK_COND(graph.is_buffer_storage(v_cache_tensor)); - VK_CHECK_COND(graph.is_buffer_storage(out)); + if (use_coop_algorithm) { + return {1, 64, 1}; + } else { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } +} - // Calculate scale factor - const int32_t head_dim_size = graph.size_at(-1, q_projected); - const float scale_val = 1.0f / std::sqrt(static_cast(head_dim_size)); +utils::uvec3 pick_sdpa_attn_weights_softmax_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef q_projected = resize_args.at(0); - // Get number of heads for multi-query attention support - const int32_t num_heads = graph.size_at(-2, q_projected); - const int32_t num_kv_heads = graph.size_at(-2, k_cache_tensor); + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); - const int32_t block_size_r = 32; // Row block size - const int32_t block_size_c = 32; // Column block size + return {1, seq_len, num_q_heads}; +} - // l and m have shape [B, H, N] - std::vector lm_sizes = { - graph.size_at(-4, q_projected), // B (batch) - graph.size_at(-2, q_projected), // H (num heads) - graph.size_at(-3, q_projected) // N (sequence length) - }; +utils::uvec3 pick_sdpa_attn_weights_softmax_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return {64, 1, 1}; +} - // t_l stores row-wise normalization sums for softmax computation - // t_m stores row-wise maximum values for numerical stability in softmax - TmpTensor t_l(&graph, lm_sizes, vkapi::kFloat, graph.storage_type_of(out)); - TmpTensor t_m(&graph, lm_sizes, vkapi::kFloat, graph.storage_type_of(out)); +vkapi::ShaderInfo pick_sdpa_compute_out_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef v_cache = args.at(1).refs.at(1); - // Choose kernel name based on storage type - std::string kernel_name = "flash_attention"; - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); + const ValueRef q_projected = resize_args.at(0); - vkapi::ParamsBindList param_ubos = { - graph.sizes_ubo(q_projected), // Q_sizes - graph.sizes_ubo(k_cache_tensor), // K_sizes - graph.sizes_ubo(v_cache_tensor), // V_sizes - graph.sizes_ubo(out), // O_sizes - graph.sizes_ubo(t_l), // l_sizes - graph.sizes_ubo(t_m), // m_sizes - graph.create_params_buffer(scale_val), // scale - graph.create_params_buffer(block_size_r), // block_size_r - 
graph.create_params_buffer(block_size_c), // block_size_c - graph.create_params_buffer(input_pos_val), // input_pos - graph.create_params_buffer(num_heads), // num_heads - graph.create_params_buffer(num_kv_heads) // num_kv_heads - }; - - // Create block size references for dispatch calculation - const ValueRef block_size_r_ref = - graph.add_scalar(static_cast(block_size_r)); + const bool is_gemv = is_single_token(graph, q_projected); - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - flash_attention_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - { - {{out, t_l, t_m}, vkapi::kReadWrite}, - {{q_projected, k_cache_tensor, v_cache_tensor}, vkapi::kRead}, - }, - // Shader param buffers - param_ubos, - // Push Constants - {}, - // Specialization Constants - {}, - // Resize Args - {q_projected, block_size_r_ref}, - // Resizing Logic - resize_flash_attention_out)); + std::string shader_name = "sdpa_compute_out"; + if (is_gemv) { + shader_name += "_coop"; + } else { + shader_name += "_tiled"; + } + + add_storage_type_suffix(shader_name, graph->storage_type_of(out)); + add_storage_type_suffix(shader_name, graph->storage_type_of(v_cache)); + add_dtype_suffix(shader_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(shader_name); } -utils::uvec3 kv_cache_update_global_wg_size( +utils::uvec3 pick_sdpa_compute_out_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, const std::vector& resize_args) { - (void)shader; - (void)resize_args; + const ValueRef q_projected = resize_args.at(0); - const ValueRef cache = args.at(0).refs.at(0); - const ValueRef projected = args.at(1).refs.at(0); + const uint32_t head_dim = graph->size_at(-1, q_projected); + const uint32_t num_q_heads = graph->size_at(-2, q_projected); + const uint32_t seq_len = graph->size_at(-3, q_projected); - if (graph->is_buffer_storage(cache)) { - return graph->create_global_wg_size(projected); + const uint32_t N4 = utils::div_up_4(head_dim); + const uint32_t M4 = utils::div_up_4(seq_len); + + return {N4, M4, num_q_heads}; +} + +utils::uvec3 pick_sdpa_compute_out_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + if (use_coop_algorithm) { + return {1, 64, 1}; } else { - return graph->logical_limits_of(projected); + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); } } -void add_kv_cache_update_node( +// +// Dispatch nodes +// + +void add_sdpa_kv_cache_update_node( ComputeGraph& graph, const ValueRef input_pos_symint, const ValueRef projected, const ValueRef cache) { - std::string kernel_name("kv_cache_update"); + std::string kernel_name("sdpa_kv_cache_update"); add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(cache)) { - param_ubos = { - graph.numel_ubo(projected), - graph.strides_ubo(cache), - graph.get_or_create_int_param_buffer(input_pos_symint)}; - } else { - param_ubos = { - graph.logical_limits_ubo(projected), - graph.get_or_create_int_param_buffer(input_pos_symint)}; - } + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(cache), + graph.sizes_ubo(projected), + 
graph.get_or_create_int_param_buffer(input_pos_symint)}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -240,121 +304,113 @@ void add_kv_cache_update_node( // Specialization Constants {}, // Resize Args - {}, + {input_pos_symint}, // Resizing Logic nullptr)); } -utils::uvec3 attn_weight_scale_and_mask_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef attn_weight = args.at(0).refs.at(0); - - if (graph->is_buffer_storage(attn_weight)) { - return { - graph->size_at(-1, attn_weight), - graph->size_at(-2, attn_weight), - graph->size_at(-3, attn_weight), - }; - } else { - return graph->logical_limits_of(attn_weight); - } -} - -void add_attn_weight_scale_and_mask_node( +void add_sdpa_compute_attn_weights_node( ComputeGraph& graph, - const ValueRef input_pos_symint, const ValueRef q_projected, - const ValueRef attn_weight) { - std::string kernel_name("sdpa_attn_weight_scale_and_mask"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(attn_weight)); - add_dtype_suffix(kernel_name, graph.dtype_of(attn_weight)); - + const ValueRef k_cache, + const ValueRef input_pos_symint, + const ValueRef attn_weights) { const int32_t head_dim_size = graph.size_at(-1, q_projected); const float scale_val = 1.0f / std::sqrt(static_cast(head_dim_size)); - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(attn_weight)) { - param_ubos = { - graph.sizes_ubo(attn_weight), - graph.strides_ubo(attn_weight), - graph.create_params_buffer(scale_val)}; - } else { - param_ubos = { - graph.logical_limits_ubo(attn_weight), - graph.get_or_create_int_param_buffer(input_pos_symint), - graph.create_params_buffer(scale_val)}; - } + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(q_projected), + graph.sizes_ubo(k_cache), + graph.get_or_create_int_param_buffer(input_pos_symint)}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - attn_weight_scale_and_mask_global_wg_size, - default_pick_local_wg_size, + pick_sdpa_compute_attn_weights_shader, + pick_sdpa_compute_attn_weights_global_wg_size, + pick_sdpa_compute_attn_weights_local_wg_size, // Inputs and Outputs - {{attn_weight, vkapi::kReadWrite}}, + {{attn_weights, vkapi::kWrite}, {{q_projected, k_cache}, vkapi::kRead}}, // Shader param buffers param_ubos, // Push Constants {}, // Specialization Constants - {}, + {scale_val}, // Resize Args - {}, + {input_pos_symint}, // Resizing Logic - nullptr)); + resize_compute_attn_weights_node)); } -std::vector get_cache_slice_sizes( +void add_sdpa_attn_weights_softmax_node( ComputeGraph& graph, - ValueRef cache, - ValueRef input_pos_symint, - ValueRef q_projected) { - std::vector slice_sizes = graph.sizes_of(cache); - - // Cache slicing will always be in the channels dim - const int32_t input_pos_val = graph.read_symint(input_pos_symint); - const int64_t q_seq_len = graph.size_at(1, q_projected); - slice_sizes.at(1) = input_pos_val + q_seq_len; - return slice_sizes; -} + const ValueRef attn_weights, + const ValueRef q_projected, + const ValueRef input_pos_symint, + const ValueRef attn_weights_softmax) { + std::string shader_name = "sdpa_attn_weights_softmax"; + add_storage_type_suffix( + shader_name, graph.storage_type_of(attn_weights_softmax)); + add_dtype_suffix(shader_name, graph.dtype_of(attn_weights_softmax)); -void resize_cache_slice_view_node( - ComputeGraph* graph, - const std::vector& args, 
- const std::vector& extra_args) { - (void)args; - std::vector slice_sizes = get_cache_slice_sizes( - *graph, extra_args[0], extra_args[1], extra_args[2]); + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(q_projected), + graph.get_or_create_int_param_buffer(input_pos_symint)}; - graph->virtual_resize(extra_args[3], slice_sizes); + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(shader_name), + pick_sdpa_attn_weights_softmax_global_wg_size, + pick_sdpa_attn_weights_softmax_local_wg_size, + // Inputs and Outputs + {{attn_weights_softmax, vkapi::kWrite}, {attn_weights, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {q_projected, input_pos_symint}, + // Resizing Logic + resize_sdpa_attn_weights_softmax_node)); } -void add_cache_slice_view_node( +void add_sdpa_compute_out_node( ComputeGraph& graph, - ValueRef cache, - ValueRef input_pos_symint, - ValueRef q_projected, - ValueRef cache_sliced, - const int64_t max_seq_len) { - std::vector slice_sizes = - get_cache_slice_sizes(graph, cache, input_pos_symint, q_projected); - // Initialize the slice to the maximum possible size to start - slice_sizes.at(1) = max_seq_len; - - graph.virtual_resize(cache_sliced, slice_sizes); - - graph.execute_nodes().emplace_back(new ExecuteNode( - resize_cache_slice_view_node, - {cache, input_pos_symint, q_projected, cache_sliced})); + const ValueRef attn_weights_softmax, + const ValueRef v_cache, + const ValueRef q_projected, + const ValueRef input_pos_symint, + const ValueRef out) { + vkapi::ParamsBindList param_ubos = { + graph.sizes_ubo(q_projected), + graph.sizes_ubo(v_cache), + graph.get_or_create_int_param_buffer(input_pos_symint)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_sdpa_compute_out_shader, + pick_sdpa_compute_out_global_wg_size, + pick_sdpa_compute_out_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{attn_weights_softmax, v_cache}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {q_projected, input_pos_symint}, + // Resizing Logic + resize_sdpa_compute_out_node)); } +// +// High level operator impl +// + void update_cache_impl(ComputeGraph& graph, const std::vector& args) { int arg_idx = 0; const ValueRef value = args[arg_idx++]; @@ -372,7 +428,7 @@ void update_cache_impl(ComputeGraph& graph, const std::vector& args) { VK_CHECK_COND( graph.size_at(-2, value) == graph.size_at(-2, cache)); - add_kv_cache_update_node(graph, input_pos_symint, value, cache); + add_sdpa_kv_cache_update_node(graph, input_pos_symint, value, cache); } void sdpa_impl(ComputeGraph& graph, const std::vector& args) { @@ -413,105 +469,39 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { graph.val_is_none(is_causal) || graph.extract_scalar(is_causal)); VK_CHECK_COND(graph.val_is_none(attn_mask)); - const int32_t max_seq_len = graph.size_at(1, k_cache); + const int64_t num_q_heads = graph.size_at(-2, q_projected); + const int64_t max_seq_len = graph.size_at(-3, q_projected); - // Slice caches from 0 to input_pos + sequence_len - const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache); - const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache); - add_cache_slice_view_node( - graph, - k_cache, - input_pos_symint, - q_projected, - k_cache_sliced, - max_seq_len); - add_cache_slice_view_node( - graph, - v_cache, - input_pos_symint, 
- q_projected, - v_cache_sliced, - max_seq_len); - - // Scalar values for various dims - const ValueRef channels = graph.add_scalar(1); - const ValueRef height = graph.add_scalar(2); - const ValueRef width = graph.add_scalar(3); - - // Repeat interleave - const int64_t num_heads = graph.size_at(2, q_projected); - const int64_t num_kv_heads = graph.size_at(2, k_cache); - - const ValueRef num_repeats = - graph.add_scalar(num_heads / num_kv_heads); - - std::vector cache_slice_repeated_sizes(graph.sizes_of(q_projected)); - cache_slice_repeated_sizes.at(1) = max_seq_len; - - TmpTensor k_cache_sliced_repeated( - &graph, cache_slice_repeated_sizes, graph.dtype_of(k_cache_sliced)); - TmpTensor v_cache_sliced_repeated( - &graph, cache_slice_repeated_sizes, graph.dtype_of(v_cache_sliced)); - - add_repeat_interleave_node( - graph, k_cache_sliced, num_repeats, height, k_cache_sliced_repeated); - add_repeat_interleave_node( - graph, v_cache_sliced, num_repeats, height, v_cache_sliced_repeated); - - // Transpose sequence and head dims - const ValueRef q_transposed = graph.add_tensor_view(q_projected); - const ValueRef k_transposed = graph.add_tensor_view(k_cache_sliced_repeated); - const ValueRef v_transposed = graph.add_tensor_view(v_cache_sliced_repeated); - - add_transpose_view_node(graph, q_projected, channels, height, q_transposed); - add_transpose_view_node( - graph, k_cache_sliced_repeated, channels, height, k_transposed); - add_transpose_view_node( - graph, v_cache_sliced_repeated, channels, height, v_transposed); - - // Transpose K again to prepare for matmul - const ValueRef k_transposed_2 = graph.add_tensor_view(k_transposed); - add_transpose_view_node(graph, k_transposed, height, width, k_transposed_2); - - // Initialize attn_weight to the maximum possible size - std::vector attn_weight_full_sizes = graph.sizes_of(q_transposed); - attn_weight_full_sizes.at(2) = max_seq_len; - attn_weight_full_sizes.at(3) = max_seq_len; - TmpTensor attn_weight( - &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed)); - - // Resize attn_weight to the correct dim - std::vector attn_weight_sizes = attn_weight_full_sizes; - attn_weight_sizes.at(2) = graph.size_at(2, q_transposed); - attn_weight_sizes.at(3) = graph.size_at(2, k_transposed); - graph.virtual_resize(attn_weight, attn_weight_sizes); - - // Calculate attention weight, which is a matmul of Q and K - const ValueRef mat2_is_transposed = graph.add_scalar(false); - add_matmul_node( - graph, q_transposed, k_transposed_2, attn_weight, mat2_is_transposed); - - // Apply scale and mask to the attention weight - add_attn_weight_scale_and_mask_node( - graph, input_pos_symint, q_projected, attn_weight); - - TmpTensor attn_weight_softmax( - &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed)); - graph.virtual_resize(attn_weight_softmax, attn_weight_sizes); - add_softmax_node(graph, attn_weight, width, attn_weight_softmax, false); - - // Calculate final output - const ValueRef out_transposed = graph.add_tensor_view(out); - add_transpose_view_node(graph, out, channels, height, out_transposed); - add_matmul_node( - graph, - attn_weight_softmax, - v_transposed, - out_transposed, - mat2_is_transposed); + const int64_t max_context_len = graph.size_at(-3, k_cache); + + std::vector attn_weight_full_sizes = { + 1, // batch + num_q_heads, + max_seq_len, + max_context_len}; + + TmpTensor attn_weights( + &graph, + attn_weight_full_sizes, + graph.dtype_of(q_projected), + graph.storage_type_of(q_projected), + utils::kWidthPacked); + + TmpTensor 
attn_weights_softmax( + &graph, + attn_weight_full_sizes, + graph.dtype_of(q_projected), + graph.storage_type_of(q_projected), + utils::kWidthPacked); + + add_sdpa_compute_attn_weights_node( + graph, q_projected, k_cache, input_pos_symint, attn_weights); + + add_sdpa_attn_weights_softmax_node( + graph, attn_weights, q_projected, input_pos_symint, attn_weights_softmax); - graph.execute_nodes().emplace_back( - new ExecuteNode(resize_sdpa_out, {q_projected, out})); + add_sdpa_compute_out_node( + graph, attn_weights_softmax, v_cache, q_projected, input_pos_symint, out); } void sdpa_with_kv_cache_impl( @@ -535,10 +525,10 @@ void sdpa_with_kv_cache_impl( (void)sequence_len; - const ValueRef k_cache = - prepack_standard_like(graph, k_cache_data, q_projected); - const ValueRef v_cache = - prepack_standard_like(graph, v_cache_data, q_projected); + const ValueRef k_cache = prepack_standard( + graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked); + const ValueRef v_cache = prepack_standard( + graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -560,7 +550,6 @@ REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl); VK_REGISTER_OP(update_cache.default, update_cache_impl); VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); - VK_REGISTER_OP(llama.flash_attention.default, flash_attention_impl); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/CMakeLists.txt b/backends/vulkan/test/op_tests/CMakeLists.txt index 07a13c3f260..5e8991f8e50 100644 --- a/backends/vulkan/test/op_tests/CMakeLists.txt +++ b/backends/vulkan/test/op_tests/CMakeLists.txt @@ -47,7 +47,7 @@ find_library(LIB_C10 c10 HINTS ${TORCH_INSTALL_PREFIX}/lib) # Third party include paths -set(VULKAN_THIRD_PARTY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../third-party) +set(VULKAN_THIRD_PARTY_PATH ${EXECUTORCH_ROOT}/backends/vulkan/third-party) set(GTEST_INCLUDE_PATH ${EXECUTORCH_ROOT}/third-party/googletest/googletest/include diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index e4b3f662c04..a94e68a53af 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -215,16 +215,13 @@ at::Tensor sdpa_reference_impl( void test_reference_sdpa( const int start_input_pos, const int sequence_len, - const int embedding_dim, + const int head_dim, const int num_heads, const int num_kv_heads, const int batch_size, const int max_seq_len, at::ScalarType dtype = at::kFloat) { - const int head_dim = embedding_dim / num_heads; - // K and V caches. 
Need an extra set for the reference implementation - at::Tensor k_cache = at::zeros( {batch_size, max_seq_len, num_kv_heads, head_dim}, at::device(at::kCPU).dtype(dtype)); @@ -265,19 +262,23 @@ void test_reference_sdpa( void test_vulkan_sdpa( const int start_input_pos, - const int base_sequence_len, - const int embedding_dim, + const std::vector& sequence_lens, + const int head_dim, const int num_heads, const int num_kv_heads, const int batch_size, - const int max_seq_len, - const bool dynamic_seq_len = true, + vkcompute::utils::StorageType storage_type, at::ScalarType dtype = at::kFloat) { - const int head_dim = embedding_dim / num_heads; + // compute the max sequence length + int max_seq_len = start_input_pos; + for (int i = 0; i < sequence_lens.size(); ++i) { + max_seq_len += sequence_lens[i]; + } + // Add some extra space to the max sequence length + max_seq_len += 128; - const int init_seq_len = dynamic_seq_len ? max_seq_len : base_sequence_len; + const int init_seq_len = max_seq_len; // K and V caches - at::Tensor k_cache = at::zeros( {batch_size, max_seq_len, num_kv_heads, head_dim}, at::device(at::kCPU).dtype(dtype)); @@ -300,7 +301,6 @@ void test_vulkan_sdpa( using namespace vkcompute; GraphConfig config; - config.set_storage_type_override(utils::kTexture3D); ComputeGraph graph(config); // "Data" variant for vulkan initialization @@ -319,7 +319,7 @@ void test_vulkan_sdpa( #define MAKE_INPUT_FOR(x) \ IOValueRef r_##x = graph.add_input_tensor( \ - x.sizes().vec(), from_at_scalartype(x.scalar_type())); + x.sizes().vec(), from_at_scalartype(x.scalar_type()), storage_type); MAKE_INPUT_FOR(q); MAKE_INPUT_FOR(k); @@ -328,7 +328,7 @@ void test_vulkan_sdpa( const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos); const ValueRef r_out = graph.add_tensor( - out.sizes().vec(), from_at_scalartype(out.scalar_type())); + out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type); VK_GET_OP_FN("sdpa_with_kv_cache.default") (graph, @@ -365,10 +365,10 @@ void test_vulkan_sdpa( graph.copy_from_staging( \ staging_##x, vk_##x.mutable_data_ptr(), vk_##x.numel()); - int seq_len = base_sequence_len; - for (int i = 0, input_pos = start_input_pos; - input_pos + seq_len < max_seq_len; - input_pos += seq_len, i++) { + torch::manual_seed(0); + + int input_pos = start_input_pos; + for (auto seq_len : sequence_lens) { q = at::rand( {batch_size, seq_len, num_heads, head_dim}, at::device(at::kCPU).dtype(dtype)); @@ -398,6 +398,46 @@ void test_vulkan_sdpa( const bool output_correct = at::allclose(reference_out, vk_out); if (!output_correct) { + // Print only differing tensor elements side by side for easier comparison + auto ref_flat = reference_out.flatten(); + auto vk_flat = vk_out.flatten(); + auto numel = ref_flat.numel(); + std::cout << "reference_out\tvk_out\tindex" << std::endl; + int first_diff_idx = -1; + auto sizes = reference_out.sizes(); + int d0 = sizes[0], d1 = sizes[1], d2 = sizes[2], d3 = sizes[3]; + for (int i = 0; i < numel; ++i) { + if (std::abs(ref_flat[i].item() - vk_flat[i].item()) > + 1e-4) { + // Compute 4-D index from flat index + int i0 = i / (d1 * d2 * d3); + int rem0 = i % (d1 * d2 * d3); + int i1 = rem0 / (d2 * d3); + int rem1 = rem0 % (d2 * d3); + int i2 = rem1 / d3; + int i3 = rem1 % d3; + std::cout << ref_flat[i].item() << "\t" << vk_flat[i].item() << "\t[" + << i0 << ", " << i1 << ", " << i2 << ", " << i3 << "]" + << std::endl; + if (first_diff_idx == -1) { + first_diff_idx = i; + } + break; + } + } + if (first_diff_idx != -1) { + // Compute 4-D index 
from flat index + int i0 = first_diff_idx / (d1 * d2 * d3); + int rem0 = first_diff_idx % (d1 * d2 * d3); + int i1 = rem0 / (d2 * d3); + int rem1 = rem0 % (d2 * d3); + int i2 = rem1 / d3; + int i3 = rem1 % d3; + std::cout << "First difference at flat index " << first_diff_idx + << " which is tensor index [" << i0 << ", " << i1 << ", " + << i2 << ", " << i3 << "]" << std::endl; + } + at::Tensor diffs = at::abs(reference_out - vk_out); std::cout << "Failed at input_pos " << input_pos << " with seq_len " @@ -414,426 +454,65 @@ void test_vulkan_sdpa( } ASSERT_TRUE(output_correct); - if (dynamic_seq_len) { - seq_len = base_sequence_len + (i % 3); - } + input_pos += seq_len; } } -TEST(VulkanSDPATest, test_sdpa_op_small_params) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 18; - const int num_heads = 6; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 7; - - test_vulkan_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len, - false); -} - -TEST(VulkanSDPATest, test_sdpa_op_small_params_dynamic) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 18; - const int num_heads = 6; - const int num_kv_heads = 2; - const int batch_size = 1; - const int max_seq_len = 12; - - test_vulkan_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_sdpa_op_llama3_params_dynamic) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 2048; - const int num_heads = 32; - const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - test_vulkan_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -TEST(VulkanSDPATest, test_reference_impl) { - const int starting_input_pos = 0; - const int base_sequence_len = 3; - const int embedding_dim = 2048; - const int num_heads = 32; - const int num_kv_heads = 8; - const int batch_size = 1; - const int max_seq_len = 128; - - test_reference_sdpa( - starting_input_pos, - base_sequence_len, - embedding_dim, - num_heads, - num_kv_heads, - batch_size, - max_seq_len); -} - -void test_vulkan_flash_attention_impl( - const int start_input_pos, - const int sequence_len, - const int embedding_dim, - const int num_heads, - const int num_kv_heads, - const int batch_size, - const int max_seq_len, - vkcompute::utils::StorageType storage_type, - at::ScalarType dtype = at::kFloat) { - const int head_dim = embedding_dim / num_heads; - - at::Tensor k_cache = at::zeros( - {batch_size, max_seq_len, num_kv_heads, head_dim}, - at::device(at::kCPU).dtype(dtype)); - at::Tensor v_cache = at::zeros_like(k_cache); - - at::Tensor q = at::rand( - {batch_size, sequence_len, num_heads, head_dim}, - at::device(at::kCPU).dtype(dtype)); - at::Tensor k = at::rand( - {batch_size, sequence_len, num_kv_heads, head_dim}, - at::device(at::kCPU).dtype(dtype)); - at::Tensor v = at::rand_like(k); - - // Get reference output using existing SDPA - at::Tensor reference_out = sdpa_reference_impl( - q, - k, - v, - k_cache, - v_cache, - start_input_pos, - sequence_len, - {}, - 0.0, - true, - {}); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(storage_type); - ComputeGraph graph(config); - - // Create input references - IOValueRef r_q 
= graph.add_input_tensor(
-      q.sizes().vec(), from_at_scalartype(q.scalar_type()));
-  IOValueRef r_k = graph.add_input_tensor(
-      k.sizes().vec(), from_at_scalartype(k.scalar_type()));
-  IOValueRef r_v = graph.add_input_tensor(
-      v.sizes().vec(), from_at_scalartype(v.scalar_type()));
-
-  // Create cache tensors (these would be updated by cache update operations in
-  // practice)
-  ValueRef r_k_cache = graph.add_tensorref(
-      k_cache.sizes().vec(),
-      from_at_scalartype(k_cache.scalar_type()),
-      k_cache.const_data_ptr());
-  ValueRef r_v_cache = graph.add_tensorref(
-      v_cache.sizes().vec(),
-      from_at_scalartype(v_cache.scalar_type()),
-      v_cache.const_data_ptr());
-
-  const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos);
-  const ValueRef r_out =
-      graph.add_tensor(q.sizes().vec(), from_at_scalartype(q.scalar_type()));
-
-  // Call Flash Attention implementation
-  VK_GET_OP_FN("llama.flash_attention.default")
-  (graph,
-   {
-       r_q.value,
-       r_k.value, // Use actual K tensor, not cache
-       r_v.value, // Use actual V tensor, not cache
-       r_input_pos_symint,
-       kDummyValueRef, // attn_mask
-       kDummyValueRef, // dropout_p
-       kDummyValueRef, // is_causal
-       kDummyValueRef, // scale
-       r_out,
-   });
-
-  ValueRef staging_out = graph.set_output_tensor(r_out);
-
-  graph.prepare();
-  graph.prepack();
-
-  // Copy inputs and run
-  graph.copy_into_staging(r_q.staging, q.const_data_ptr(), q.numel());
-  graph.copy_into_staging(r_k.staging, k.const_data_ptr(), k.numel());
-  graph.copy_into_staging(r_v.staging, v.const_data_ptr(), v.numel());
-
-  graph.execute();
-
-  // Extract output
-  at::Tensor vk_out = at::zeros_like(q).contiguous();
-  graph.copy_from_staging(
-      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
-
-  // Compare results
-  const bool output_correct = at::allclose(reference_out, vk_out, 1e-3, 1e-3);
-
-  if (!output_correct) {
-    at::Tensor diffs = at::abs(reference_out - vk_out);
-    std::cout << "Maximum difference: " << at::max(diffs).item() << std::endl;
-    std::cout << "Maximum value observed: "
-              << at::max(at::abs(at::cat({reference_out, vk_out}, -1))).item()
-              << std::endl;
-  }
-  ASSERT_TRUE(output_correct);
-}
-
-void test_vulkan_flash_attention(
+void test_vulkan_sdpa(
     const int start_input_pos,
-    const int sequence_len,
-    const int embedding_dim,
+    const std::vector<int>& sequence_lens,
+    const int head_dim,
     const int num_heads,
     const int num_kv_heads,
     const int batch_size,
-    const int max_seq_len,
     at::ScalarType dtype = at::kFloat) {
-  test_vulkan_flash_attention_impl(
+  // Test texture
+  test_vulkan_sdpa(
      start_input_pos,
-      sequence_len,
-      embedding_dim,
+      sequence_lens,
+      head_dim,
      num_heads,
      num_kv_heads,
      batch_size,
-      max_seq_len,
-      vkcompute::utils::kBuffer,
+      vkcompute::utils::kTexture3D,
      dtype);

-  test_vulkan_flash_attention_impl(
+  // Test buffer
+  test_vulkan_sdpa(
      start_input_pos,
-      sequence_len,
-      embedding_dim,
+      sequence_lens,
+      head_dim,
      num_heads,
      num_kv_heads,
      batch_size,
-      max_seq_len,
-      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kBuffer,
      dtype);
 }

-// Flash Attention Tests (both Buffer and Texture)
-TEST(VulkanSDPATest, test_flash_attention_small_params) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 2;
-  const int embedding_dim = 4;
-  const int num_heads = 2;
-  const int num_kv_heads = 2;
-  const int batch_size = 1;
-  const int max_seq_len = 4;
-
-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
-}
-
-TEST(VulkanSDPATest, test_flash_attention_multi_tile) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 48;
-  const int embedding_dim = 32;
-  const int num_heads = 2;
-  const int num_kv_heads = 2;
-  const int batch_size = 1;
-  const int max_seq_len = 64;
-
-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
-}
-
-TEST(VulkanSDPATest, test_flash_attention_op_small_params) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 3;
-  const int embedding_dim = 18;
-  const int num_heads = 6;
-  const int num_kv_heads = 2;
-  const int batch_size = 1;
-  const int max_seq_len = 7;
+TEST(VulkanSDPATest, test_sdpa_op_small_params) {
+  const int base_sequence_len = 3;
+  const int num_heads = 8;
+  const int head_dim = 4;
+  const int num_kv_heads = 4;

-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
+  test_vulkan_sdpa(
+      0, {3, 1, 1, 5, 1, 1, 2}, head_dim, num_heads, num_kv_heads, 1);
 }

-TEST(VulkanSDPATest, test_flash_attention_op_small_params_dynamic) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 3;
-  const int embedding_dim = 18;
+TEST(VulkanSDPATest, test_sdpa_op_small_params_dynamic) {
+  const int base_sequence_len = 3;
+  const int head_dim = 8;
   const int num_heads = 6;
   const int num_kv_heads = 2;
-  const int batch_size = 1;
-  const int max_seq_len = 12;

-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
+  test_vulkan_sdpa(0, {3, 1, 1, 5, 1, 1}, head_dim, num_heads, num_kv_heads, 1);
 }

-TEST(VulkanSDPATest, test_flash_attention_op_llama3_params) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 3;
-  const int embedding_dim = 2048;
-  const int num_heads = 32;
-  const int num_kv_heads = 8;
-  const int batch_size = 1;
-  const int max_seq_len = 128;
-
-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
-}
-
-TEST(VulkanSDPATest, test_flash_attention_op_llama3_params_dynamic) {
-  const int starting_input_pos = 0;
-  const int embedding_dim = 2048;
-  const int num_heads = 32;
-  const int num_kv_heads = 8;
-  const int batch_size = 1;
-  const int max_seq_len = 128;
-
-  // Test with different sequence lengths
-  std::vector<int> sequence_lengths = {1, 3, 5, 7, 16, 32};
-
-  for (int seq_len : sequence_lengths) {
-    if (seq_len < max_seq_len) {
-      test_vulkan_flash_attention(
-          starting_input_pos,
-          seq_len,
-          embedding_dim,
-          num_heads,
-          num_kv_heads,
-          batch_size,
-          max_seq_len);
-    }
-  }
-}
-
-TEST(VulkanSDPATest, test_flash_attention_reference_impl) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 3;
-  const int embedding_dim = 2048;
-  const int num_heads = 32;
+TEST(VulkanSDPATest, test_sdpa_op_llama3_params_dynamic) {
+  const int head_dim = 128;
+  const int num_heads = 24;
   const int num_kv_heads = 8;
-  const int batch_size = 1;
-  const int max_seq_len = 128;
-
-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
-}
-
-TEST(VulkanSDPATest, test_flash_attention_reference_impl_small) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 2;
-  const int embedding_dim = 32;
-  const int num_heads = 4;
-  const int num_kv_heads = 2;
-  const int batch_size = 1;
-  const int max_seq_len = 16;
-
-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
-}
-
-TEST(VulkanSDPATest, test_flash_attention_vec4_alignment) {
-  const int starting_input_pos = 0;
-  const int sequence_len = 8;
-  const int embedding_dim = 64;
-  const int num_heads = 4;
-  const int num_kv_heads = 2;
-  const int batch_size = 1;
-  const int max_seq_len = 16;
-
-  test_vulkan_flash_attention(
-      starting_input_pos,
-      sequence_len,
-      embedding_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      max_seq_len);
-}
-
-TEST(VulkanSDPATest, test_flash_attention_edge_cases) {
-  // Test with single head (no multi-query complexity)
-  test_vulkan_flash_attention(0, 1, 8, 1, 1, 1, 4);
-
-  // Test with equal heads (no multi-query complexity)
-  test_vulkan_flash_attention(0, 2, 16, 4, 4, 1, 8);
-
-  // Test with large head dimension
-  test_vulkan_flash_attention(0, 2, 128, 2, 1, 1, 8);
-
-  // Test with sequence length that exactly matches block size (32)
-  test_vulkan_flash_attention(0, 32, 64, 2, 1, 1, 64);
-
-  // Test with sequence length slightly larger than block size
-  test_vulkan_flash_attention(
-      0, 33, 68, 2, 1, 1, 64); // 68 = 4*17, good for vec4
+  test_vulkan_sdpa(
+      0, {111, 1, 1, 1, 57, 1, 1}, head_dim, num_heads, num_kv_heads, 1);
 }
diff --git a/backends/vulkan/test/scripts/test_op.sh b/backends/vulkan/test/scripts/test_op.sh
index 36920cb73cc..1ec07b7f75f 100755
--- a/backends/vulkan/test/scripts/test_op.sh
+++ b/backends/vulkan/test/scripts/test_op.sh
@@ -141,6 +141,8 @@ build_core_libraries() {
     -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
     -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
     -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
+    -DEXECUTORCH_BUILD_KERNELS_LLM=ON \
+    -DEXECUTORCH_BUILD_KERNELS_LLM_AOT=ON \
     -DEXECUTORCH_BUILD_DEVTOOLS=ON \
     -DEXECUTORCH_BUILD_VULKAN=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
@@ -152,39 +154,45 @@ build_core_libraries() {
 build_operator_tests() {
   echo "Building Vulkan operator tests..."

-  # Check if TORCH_OPS_YAML_PATH is set, if not use default
-  if [[ -z "${TORCH_OPS_YAML_PATH:-}" ]]; then
-    TORCH_OPS_YAML_PATH="$HOME/Github/pytorch/aten/src/ATen/native"
-    echo "Using default TORCH_OPS_YAML_PATH: $TORCH_OPS_YAML_PATH"
-  fi
+  # Prepare CMAKE arguments
+  CMAKE_ARGS=(
+    "backends/vulkan/test/op_tests"
+    "-DCMAKE_INSTALL_PREFIX=cmake-out"
+    "-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE"
+    "-DCMAKE_CXX_STANDARD=17"
+  )

-  # Verify that TORCH_OPS_YAML_PATH exists
-  if [[ ! -d "$TORCH_OPS_YAML_PATH" ]]; then
-    echo "Error: TORCH_OPS_YAML_PATH directory does not exist: $TORCH_OPS_YAML_PATH"
-    echo "Please set TORCH_OPS_YAML_PATH to a valid PyTorch native operations directory"
-    echo "Example: export TORCH_OPS_YAML_PATH=/path/to/pytorch/aten/src/ATen/native"
-    exit 1
-  fi
+  # Check if TORCH_OPS_YAML_PATH is set
+  if [[ -n "${TORCH_OPS_YAML_PATH:-}" ]]; then
+    # Verify that TORCH_OPS_YAML_PATH exists
+    if [[ ! -d "$TORCH_OPS_YAML_PATH" ]]; then
+      echo "Error: TORCH_OPS_YAML_PATH directory does not exist: $TORCH_OPS_YAML_PATH"
+      echo "Please set TORCH_OPS_YAML_PATH to a valid PyTorch native operations directory"
+      echo "Example: export TORCH_OPS_YAML_PATH=/path/to/pytorch/aten/src/ATen/native"
+      exit 1
+    fi

-  # Verify required YAML files exist
-  if [[ ! -f "$TORCH_OPS_YAML_PATH/native_functions.yaml" ]]; then
-    echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/native_functions.yaml"
-    exit 1
-  fi
+    # Verify required YAML files exist
+    if [[ ! -f "$TORCH_OPS_YAML_PATH/native_functions.yaml" ]]; then
+      echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/native_functions.yaml"
+      exit 1
+    fi

-  if [[ ! -f "$TORCH_OPS_YAML_PATH/tags.yaml" ]]; then
-    echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/tags.yaml"
-    exit 1
-  fi
+    if [[ ! -f "$TORCH_OPS_YAML_PATH/tags.yaml" ]]; then
+      echo "Error: Required file not found: $TORCH_OPS_YAML_PATH/tags.yaml"
+      exit 1
+    fi

-  echo "Using TORCH_OPS_YAML_PATH: $TORCH_OPS_YAML_PATH"
+    echo "Using TORCH_OPS_YAML_PATH: $TORCH_OPS_YAML_PATH"
+    CMAKE_ARGS+=("-DTORCH_OPS_YAML_PATH=$TORCH_OPS_YAML_PATH")
+  else
+    echo "WARNING: TORCH_OPS_YAML_PATH is not set. Building without PyTorch operator definitions."
+    echo "Some functionality may be limited. To enable full functionality, set TORCH_OPS_YAML_PATH to point to PyTorch's native operations directory."
+    echo "Example: export TORCH_OPS_YAML_PATH=/path/to/pytorch/aten/src/ATen/native"
+  fi

   # Build operator tests
-  cmake backends/vulkan/test/op_tests \
-    -DCMAKE_INSTALL_PREFIX=cmake-out \
-    -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-    -DTORCH_OPS_YAML_PATH="$TORCH_OPS_YAML_PATH" \
-    -DCMAKE_CXX_STANDARD=17 \
+  cmake "${CMAKE_ARGS[@]}" \
     -Bcmake-out/backends/vulkan/test/op_tests && \
   cmake --build cmake-out/backends/vulkan/test/op_tests -j16
 }