diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index f7de7e183de..b61bd4a51c0 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -639,6 +639,10 @@ class ComputeGraph final {
 
   bool device_name_contains(const char* substr);
 
+  int64_t max_buffer_numel() {
+    return static_cast<int64_t>(context_->adapter_ptr()->max_buffer_numel());
+  }
+
   //
   // Graph Building
   //
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 4eed8b82834..d28d2c90fcb 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
   const int64_t num_q_heads = graph.size_at<int64_t>(-2, q_projected);
-  const int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
-
+  int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
   const int64_t max_context_len = graph.size_at<int64_t>(-3, k_cache);
 
+  const utils::StorageType attn_weights_storage =
+      graph.storage_type_of(q_projected);
+
+  // If buffer storage is used for the attn weights, ensure that the buffer
+  // numel limit is not exceeded. If needed, manually adjust max_seq_len
+  // based on the buffer numel limit.
+  if (attn_weights_storage == utils::kBuffer) {
+    const int64_t max_buffer_numel = graph.max_buffer_numel();
+    if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
+      // Compute the largest max_seq_len that does not exceed the buffer
+      // numel limit.
+      max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
+      // Round down to the nearest multiple of 4; if already a multiple of
+      // 4, step down by 4 so the limit is not hit.
+      if (max_seq_len % 4 != 0) {
+        max_seq_len = (max_seq_len / 4) * 4;
+      } else {
+        max_seq_len -= 4;
+      }
+    }
+  }
+
   std::vector<int64_t> attn_weight_full_sizes = {
       1, // batch
       num_q_heads,
@@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   TmpTensor attn_weights_softmax(
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   add_sdpa_compute_attn_weights_node(
@@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl(
 
   utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl(
 
   (void)sequence_len;
 
-  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
+  const utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
       graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
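
To make the new clamping logic concrete, here is a minimal standalone sketch of the same arithmetic. The free function `clamp_max_seq_len`, the concrete head/context counts, and the 128 MiB example limit are all hypothetical; in the patch itself the limit comes from the Vulkan adapter via `graph.max_buffer_numel()`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical standalone version of the clamp added in sdpa_impl: shrink
// max_seq_len so that the attention-weights tensor, which holds
// num_q_heads * max_seq_len * max_context_len elements, stays below the
// device's buffer numel limit, while keeping max_seq_len a multiple of 4.
int64_t clamp_max_seq_len(
    int64_t num_q_heads,
    int64_t max_seq_len,
    int64_t max_context_len,
    int64_t max_buffer_numel) {
  if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
    // Largest max_seq_len whose product does not exceed the limit.
    max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
    if (max_seq_len % 4 != 0) {
      // Round down to the nearest multiple of 4.
      max_seq_len = (max_seq_len / 4) * 4;
    } else {
      // Already aligned; step down so the product is strictly below the limit.
      max_seq_len -= 4;
    }
  }
  return max_seq_len;
}

int main() {
  // Example values only: 32 heads, context length 4096, and a 128 MiB limit
  // on 4-byte elements, i.e. 33554432 elements.
  const int64_t limit = (128ll << 20) / 4;
  const int64_t clamped = clamp_max_seq_len(32, 512, 4096, limit);
  // 33554432 / (32 * 4096) = 256, which is already a multiple of 4, so 252.
  assert(clamped == 252);
  assert(32 * clamped * 4096 < limit);
  return 0;
}
```

Note the `>=` in the guard: a product exactly equal to the limit also triggers the clamp, so a quotient that is already a multiple of 4 is stepped down by 4 rather than kept as-is.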