diff --git a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl index 028e02d1a20..b780cdce6fe 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl @@ -14,6 +14,10 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; +#define DEBUG_MODE + +#extension GL_EXT_debug_printf : enable + #include "common.glslh" ${layout_declare_tensor(B, "w", "t_cache", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} @@ -80,13 +84,17 @@ void main() { const int S = projected_sizes.z; const int H = projected_sizes.y; - if (d4 >= D4 || s >= S || h >= H) { + const int c = s + input_pos; // idx along max_context_len dim + const int C = cache_sizes.z; + + if (d4 >= D4 || c >= C || h >= H) { return; } - const int c = s + input_pos; // idx along max_context_len dim - const int C = cache_sizes.y; + IN_VEC4_T in_texel = IN_VEC4_T(0.0); + if (s < S) { + in_texel = read_projected_d4(d4, h, s, D4, H, S); + } - IN_VEC4_T in_texel = read_projected_d4(d4, h, s, D4, H, S); write_cache_d4(in_texel, d4, c, h, D4, C, H); } diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index f514530f175..4eed8b82834 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -575,9 +575,9 @@ void compute_attn_weight_with_kv_cache_impl( utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = - prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache = - prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});