From 8080cab171664f8413d2374f02474700dae05580 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Tue, 4 Nov 2025 11:52:39 -0800
Subject: [PATCH 1/2] [ET-VK][ez] Refactor yaml configs for SDPA shaders

Title says it all! Use the new combos codegen API, which makes it easier to
express storage type combinations when generating shader variants.
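
As a rough sketch of how the generated variants get selected at runtime (the
code mirrors the add_sdpa_kv_cache_update_node hunk below; the expanded
variant name in the comment is an inferred example, not an exhaustive list):

    // Each storage combo x DTYPE expands into one shader variant, e.g.
    // sdpa_kv_cache_update_texture3d_buffer_half. The runtime rebuilds the
    // same name by appending suffixes to the base name:
    std::string kernel_name("sdpa_kv_cache_update");
    add_storage_type_suffix(kernel_name, graph.storage_type_of(cache));     // OUTPUT_STORAGE
    add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); // INPUT_STORAGE
    add_dtype_suffix(kernel_name, graph.dtype_of(projected));               // DTYPE
    // kernel_name must now match one of the generated variant names.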

Differential Revision: [D86226138](https://our.internmc.facebook.com/intern/diff/D86226138/)

ghstack-source-id: 320850476
Pull Request resolved: https://github.com/pytorch/executorch/pull/15576
---
 .../graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml | 10 +++++++---
 .../ops/glsl/sdpa_compute_attn_weights_tiled.yaml      | 10 +++++++---
 .../runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml  | 10 +++++++---
 .../runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml | 10 +++++++---
 .../runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl   |  2 ++
 .../runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml   | 10 +++++++---
 backends/vulkan/runtime/graph/ops/impl/SDPA.cpp        |  1 +
 7 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml
index 6a4cffcc913..d5cadc36060 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml
@@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d
-    - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_attn_weights_coop
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml
index 6aadbbc379e..7fc016cf3c3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml
@@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d
-    - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_attn_weights_tiled
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml
index ccebf8f7c1c..33ec2f8b322 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml
@@ -12,10 +12,14 @@ sdpa_compute_out_coop:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_out_coop_texture3d_texture3d
-    - NAME: sdpa_compute_out_coop_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_out_coop
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml
index 7fbce29e908..eac2c6f37dd 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml
@@ -13,10 +13,14 @@ sdpa_compute_out_tiled:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
      - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_out_tiled_texture3d_texture3d
-    - NAME: sdpa_compute_out_tiled_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_out_tiled
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
index 932696fff02..028e02d1a20 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
@@ -5,6 +5,8 @@
 #define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
 #define T ${buffer_scalar_type(DTYPE)}
 
+$if OUTPUT_STORAGE == "buffer":
+  #define OUTPUT_BUFFER
 $if INPUT_STORAGE == "buffer":
   #define INPUT_BUFFER
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml
index 85f4ce090f8..5ec2f3e190c 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml
@@ -10,10 +10,14 @@ sdpa_kv_cache_update:
     INPUT_STORAGE: texture3d
     OUTPUT_STORAGE: texture3d
   generate_variant_forall:
+    combination:
+      parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [texture3d, buffer]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: half
       - VALUE: float
   shader_variants:
-    - NAME: sdpa_kv_cache_update_texture3d
-    - NAME: sdpa_kv_cache_update_buffer
-      INPUT_STORAGE: buffer
+    - NAME: sdpa_kv_cache_update
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 8edaebd11ff..92b14c3b724 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node(
     const ValueRef projected,
     const ValueRef cache) {
   std::string kernel_name("sdpa_kv_cache_update");
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(cache));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(projected));
   add_dtype_suffix(kernel_name, graph.dtype_of(projected));
 

From cf49a75b3bb0aadc056299a38bfde8ec41225f47 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Tue, 4 Nov 2025 11:52:43 -0800
Subject: [PATCH 2/2] [ET-VK][ez] Update SDPA test to be able to test
 different SDPA modes

Title says it all! The purpose of this diff is twofold:

1. Test SDPA both as a fused operator (sdpa_with_kv_cache) and as the
   decomposed update_cache and custom_sdpa ops (sketched just below), in
   order to detect possible regressions in supporting older models.
2. Make it easier to debug issues with SDPA by exposing a mode that tests
   only the attention weight computation.
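
At a high level, the two equivalent paths the test exercises look like this
(a sketch of the op sequencing only; argument lists are abbreviated relative
to the real op signatures registered in SDPA.cpp):

    // FUSED: a single op owns the cache updates and the attention compute.
    //   out = sdpa_with_kv_cache(q, k, v, k_cache, v_cache, input_pos, ...);
    //
    // DECOMPOSED: the same computation expressed as three separate ops,
    // matching what newer export flows emit.
    //   update_cache(k, k_cache, input_pos);
    //   update_cache(v, v_cache, input_pos);
    //   out = custom_sdpa(q, k_cache, v_cache, input_pos, ...);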

Additionally, update the SDPA op to use buffer storage for the cache tensors
if the projected tensors are buffer-backed. Also included is a small change
to ensure that cache tensors use the same storage type as the input tensors.
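
Concretely, the storage selection boils down to the following (a sketch that
mirrors the SDPA.cpp hunks below):

    // Cache tensors inherit their storage type from q_projected instead of
    // being hard-coded to texture storage:
    utils::StorageType cache_storage = graph.storage_type_of(q_projected);
    const ValueRef k_cache = prepack_standard(
        graph, k_cache_data, cache_storage, utils::kWidthPacked);
    const ValueRef v_cache = prepack_standard(
        graph, v_cache_data, cache_storage, utils::kWidthPacked);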

Differential Revision: [D86226135](https://our.internmc.facebook.com/intern/diff/D86226135/)

ghstack-source-id: 320850473
Pull Request resolved: https://github.com/pytorch/executorch/pull/15577
---
 backends/vulkan/op_registry.py                |   2 +-
 .../vulkan/runtime/graph/ops/impl/SDPA.cpp    |  50 ++++-
 backends/vulkan/test/op_tests/sdpa_test.cpp   | 206 ++++++++++++++----
 3 files changed, 210 insertions(+), 48 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b47a8f383a0..7672a2d891c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op():
 @update_features("llama::sdpa_with_kv_cache")
 def register_sdpa_with_kv_cache_op():
     return OpFeatures(
-        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        inputs_storage=utils.CONTIGUOUS_ANY,
         supports_resize=True,
         supports_prepacking=True,
     )
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 92b14c3b724..6b4da5d95f1 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -526,10 +526,11 @@ void sdpa_with_kv_cache_impl(
 
   (void)sequence_len;
 
-  const ValueRef k_cache = prepack_standard(
-      graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked);
-  const ValueRef v_cache = prepack_standard(
-      graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked);
+  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
+  const ValueRef k_cache =
+      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+  const ValueRef v_cache =
+      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -547,10 +548,51 @@ void sdpa_with_kv_cache_impl(
       out});
 }
 
+void compute_attn_weight_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  (void)attn_mask;
+  const ValueRef dropout_p = args[arg_idx++];
+  (void)dropout_p;
+  const ValueRef is_causal = args[arg_idx++];
+  (void)is_causal;
+  const ValueRef scale = args[arg_idx++];
+  (void)scale;
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
+  const ValueRef k_cache =
+      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+  const ValueRef v_cache =
+      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  add_sdpa_compute_attn_weights_node(
+      graph, q_projected, k_cache, input_pos_symint, out);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
   VK_REGISTER_OP(update_cache.default, update_cache_impl);
   VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
+  VK_REGISTER_OP(
+      testing.compute_attn_weight_with_kv_cache.default,
+      compute_attn_weight_with_kv_cache_impl);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp
index a94e68a53af..c3347b339a7 100644
--- a/backends/vulkan/test/op_tests/sdpa_test.cpp
+++ b/backends/vulkan/test/op_tests/sdpa_test.cpp
@@ -23,6 +23,24 @@
 #include
 #include
 
+//
+// SDPA Mode Enum
+//
+
+enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY };
+
+std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) {
+  switch (mode) {
+    case SDPAMode::DECOMPOSED:
+      return os << "DECOMPOSED";
+    case SDPAMode::FUSED:
+      return os << "FUSED";
+    case SDPAMode::ATTN_WEIGHT_ONLY:
+      return os << "ATTN_WEIGHT_ONLY";
+  }
+  return os;
+}
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten(
     const int64_t seq_len,
     // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
-    const std::optional<at::Tensor> attn_mask,
+    const std::optional<at::Tensor>& attn_mask,
     const double dropout_p,
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
@@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl(
     at::Tensor& value_cache,
     const int64_t start_pos,
     const int64_t seq_len,
-    const std::optional<at::Tensor> __attn_mask_ignored,
+    const std::optional<at::Tensor>& __attn_mask_ignored,
     const double dropout_p,
     const bool is_causal,
-    const std::optional<double> scale) {
+    const std::optional<double> scale,
+    SDPAMode mode = SDPAMode::DECOMPOSED) {
   at::Tensor attn_mask =
       construct_attention_mask(q_projected, key_cache, start_pos);
 
@@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl(
   float scale_factor = 1.0 / sqrt(q_transposed.size(-1));
   at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;
 
+  if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
+    return attn_weight;
+  }
+
   at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
   at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);
 
@@ -268,7 +291,8 @@ void test_vulkan_sdpa(
     const int num_kv_heads,
     const int batch_size,
     vkcompute::utils::StorageType storage_type,
-    at::ScalarType dtype = at::kFloat) {
+    at::ScalarType dtype = at::kFloat,
+    SDPAMode mode = SDPAMode::DECOMPOSED) {
   // compute the max sequence length
   int max_seq_len = start_input_pos;
   for (int i = 0; i < sequence_lens.size(); ++i) {
@@ -296,6 +320,9 @@ void test_vulkan_sdpa(
 
   // Get reference output
   at::Tensor out = at::empty_like(q);
+  if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
+    out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len});
+  }
 
   // Build Vulkan SDPA graph
   using namespace vkcompute;
@@ -330,22 +357,87 @@ void test_vulkan_sdpa(
   const ValueRef r_out = graph.add_tensor(
       out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type);
 
-  VK_GET_OP_FN("sdpa_with_kv_cache.default")
-  (graph,
-   {
-       r_q.value,
-       r_k.value,
-       r_v.value,
-       r_k_cache_data,
-       r_v_cache_data,
-       r_input_pos_symint,
-       kDummyValueRef, // sequence_len
-       kDummyValueRef, // attn_mask
-       kDummyValueRef, // dropout_p
-       kDummyValueRef, // is_causal
-       kDummyValueRef, // scale
-       r_out,
-   });
+  switch (mode) {
+    case SDPAMode::DECOMPOSED: {
+      const ValueRef r_k_cache = graph.add_tensor(
+          k_cache_data.sizes().vec(),
+          from_at_scalartype(k_cache_data.scalar_type()),
+          storage_type);
+      const ValueRef r_v_cache = graph.add_tensor(
+          v_cache_data.sizes().vec(),
+          from_at_scalartype(v_cache_data.scalar_type()),
+          storage_type);
+      const ValueRef r_dummy_out = graph.add_tensor(
+          {1}, from_at_scalartype(out.scalar_type()), utils::kBuffer);
+      VK_GET_OP_FN("update_cache.default")
+      (graph,
+       {
+           r_k.value,
+           r_k_cache,
+           r_input_pos_symint,
+           r_dummy_out,
+       });
+      VK_GET_OP_FN("update_cache.default")
+      (graph,
+       {
+           r_v.value,
+           r_v_cache,
+           r_input_pos_symint,
+           r_dummy_out,
+       });
+      VK_GET_OP_FN("llama.custom_sdpa.default")
+      (graph,
+       {
+           r_q.value,
+           r_k_cache,
+           r_v_cache,
+           r_input_pos_symint,
+           kDummyValueRef, // attn_mask
+           kDummyValueRef, // dropout_p
+           kDummyValueRef, // is_causal
+           kDummyValueRef, // scale
+           r_out,
+       });
+    } break;
+    case SDPAMode::FUSED:
+      VK_GET_OP_FN("sdpa_with_kv_cache.default")
+      (graph,
+       {
+           r_q.value,
+           r_k.value,
+           r_v.value,
+           r_k_cache_data,
+           r_v_cache_data,
+           r_input_pos_symint,
+           kDummyValueRef, // sequence_len
+           kDummyValueRef, // attn_mask
+           kDummyValueRef, // dropout_p
+           kDummyValueRef, // is_causal
+           kDummyValueRef, // scale
+           r_out,
+       });
+      break;
+    case SDPAMode::ATTN_WEIGHT_ONLY:
+      VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default")
+      (graph,
+       {
+           r_q.value,
+           r_k.value,
+           r_v.value,
+           r_k_cache_data,
+           r_v_cache_data,
+           r_input_pos_symint,
+           kDummyValueRef, // sequence_len
+           kDummyValueRef, // attn_mask
+           kDummyValueRef, // dropout_p
+           kDummyValueRef, // is_causal
+           kDummyValueRef, // scale
+           r_out,
+       });
+      break;
+    default:
+      VK_THROW("Unsupported SDPA mode");
+  }
 
   ValueRef staging_out = graph.set_output_tensor(r_out);
 
@@ -378,7 +470,7 @@ void test_vulkan_sdpa(
     v = at::rand_like(k);
 
     at::Tensor reference_out = sdpa_reference_impl(
-        q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {});
+        q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode);
 
     graph.set_symint(r_input_pos_symint, input_pos);
     graph.resize_input(0, q.sizes().vec());
@@ -393,15 +485,38 @@ void test_vulkan_sdpa(
 
     graph.execute();
 
-    out = at::empty_like(q);
+    if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
+      const int context_len = input_pos + seq_len;
+      const int context_len_align_up4 = (context_len + 3) & ~3;
+      const int seq_len_align_up4 = (seq_len + 3) & ~3;
+
+      out = at::empty(
+          {batch_size, num_heads, seq_len_align_up4, context_len_align_up4},
+          q.options());
+    } else {
+      out = at::empty_like(q);
+    }
     EXTRACT_TENSOR(out);
 
+    if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
+      // Index vk_out to only include the relevant seq_len and context_len
+      // dimensions
+      int context_len = input_pos + seq_len;
+      vk_out = vk_out.index(
+          {at::indexing::Slice(),
+           at::indexing::Slice(),
+           at::indexing::Slice(0, seq_len),
+           at::indexing::Slice(0, context_len)});
+    }
+
     const bool output_correct = at::allclose(reference_out, vk_out);
     if (!output_correct) {
       // Print only differing tensor elements side by side for easier comparison
       auto ref_flat = reference_out.flatten();
       auto vk_flat = vk_out.flatten();
       auto numel = ref_flat.numel();
+      std::cout << "While testing " << mode << " mode with " << storage_type
+                << " storage" << std::endl;
       std::cout << "reference_out\tvk_out\tindex" << std::endl;
       int first_diff_idx = -1;
       auto sizes = reference_out.sizes();
@@ -466,27 +581,32 @@
     const int num_kv_heads,
     const int batch_size,
     at::ScalarType dtype = at::kFloat) {
-  // Test texture
-  test_vulkan_sdpa(
-      start_input_pos,
-      sequence_lens,
-      head_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      vkcompute::utils::kTexture3D,
-      dtype);
-
-  // Test buffer
-  test_vulkan_sdpa(
-      start_input_pos,
-      sequence_lens,
-      head_dim,
-      num_heads,
-      num_kv_heads,
-      batch_size,
-      vkcompute::utils::kBuffer,
-      dtype);
+  for (SDPAMode mode :
+       {SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) {
+    // Test texture
+    test_vulkan_sdpa(
+        start_input_pos,
+        sequence_lens,
+        head_dim,
+        num_heads,
+        num_kv_heads,
+        batch_size,
+        vkcompute::utils::kTexture3D,
+        dtype,
+        mode);
+
+    // Test buffer
+    test_vulkan_sdpa(
+        start_input_pos,
+        sequence_lens,
+        head_dim,
+        num_heads,
+        num_kv_heads,
+        batch_size,
+        vkcompute::utils::kBuffer,
+        dtype,
+        mode);
+  }
 }
 
 TEST(VulkanSDPATest, test_sdpa_op_small_params) {