From 5cba1fb4dd0da95ee04270b41345ae2d4b326d64 Mon Sep 17 00:00:00 2001
From: ssjia
Date: Wed, 5 Nov 2025 12:05:22 -0500
Subject: [PATCH] [ET-VK][ez] Don't copy zeros for cache tensors

Currently, cache tensors for SDPA are prepacked even though the mutable
buffer data just contains zeros. For fused SDPA, this step can be
skipped.

Differential Revision: [D86226137](https://our.internmc.facebook.com/intern/diff/D86226137/)

ghstack-source-id: 320850472
Pull Request resolved: https://github.com/pytorch/executorch/pull/15580
---
 .../graph/ops/glsl/sdpa_kv_cache_update.glsl    | 12 ++++++++----
 backends/vulkan/runtime/graph/ops/impl/SDPA.cpp |  4 ++--
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
index 028e02d1a20..b780cdce6fe 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl
@@ -80,13 +80,17 @@ void main() {
   const int S = projected_sizes.z;
   const int H = projected_sizes.y;
 
-  if (d4 >= D4 || s >= S || h >= H) {
+  const int c = s + input_pos; // idx along max_context_len dim
+  const int C = cache_sizes.z;
+
+  if (d4 >= D4 || c >= C || h >= H) {
     return;
   }
 
-  const int c = s + input_pos; // idx along max_context_len dim
-  const int C = cache_sizes.y;
+  IN_VEC4_T in_texel = IN_VEC4_T(0.0);
+  if (s < S) {
+    in_texel = read_projected_d4(d4, h, s, D4, H, S);
+  }
 
-  IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S);
   write_cache_d4(in_texel, d4, c, h, D4, C, H);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index f514530f175..4eed8b82834 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -575,9 +575,9 @@ void compute_attn_weight_with_kv_cache_impl(
   utils::StorageType cache_storage = graph.storage_type_of(q_projected);
 
   const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});