@@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
471471 VK_CHECK_COND (graph.val_is_none (attn_mask));
472472
473473 const int64_t num_q_heads = graph.size_at <int64_t >(-2 , q_projected);
474- const int64_t max_seq_len = graph.size_at <int64_t >(-3 , q_projected);
475-
474+ int64_t max_seq_len = graph.size_at <int64_t >(-3 , q_projected);
476475 const int64_t max_context_len = graph.size_at <int32_t >(-3 , k_cache);
477476
477+ const utils::StorageType attn_weights_storage =
478+ graph.storage_type_of (q_projected);
479+
480+ // If using buffer storage for attn weights, we need to ensure that the buffer
481+ // numel limit is not exceeded. If needed, manually adjust max_seq_len based
482+ // on the buffer numel limit.
483+ if (attn_weights_storage == utils::kBuffer ) {
484+ const int64_t max_buffer_numel = graph.max_buffer_numel ();
485+ if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
486+ // Compute the maximum possible value for max_seq_len that will hit
487+ // the buffer numel limit.
488+ max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
489+ // Adjust down to the nearest multiple of 4 to make sure the limit is
490+ // not hit.
491+ if (max_seq_len % 4 != 0 ) {
492+ max_seq_len = (max_seq_len / 4 ) * 4 ;
493+ } else {
494+ max_seq_len -= 4 ;
495+ }
496+ }
497+ }
498+
478499 std::vector<int64_t > attn_weight_full_sizes = {
479500 1 , // batch
480501 num_q_heads,
@@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
485506 &graph,
486507 attn_weight_full_sizes,
487508 graph.dtype_of (q_projected),
488- graph. storage_type_of (q_projected) ,
509+ attn_weights_storage ,
489510 utils::kWidthPacked );
490511
491512 TmpTensor attn_weights_softmax (
492513 &graph,
493514 attn_weight_full_sizes,
494515 graph.dtype_of (q_projected),
495- graph. storage_type_of (q_projected) ,
516+ attn_weights_storage ,
496517 utils::kWidthPacked );
497518
498519 add_sdpa_compute_attn_weights_node (
@@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl(
528549
529550 utils::StorageType cache_storage = graph.storage_type_of (q_projected);
530551 const ValueRef k_cache =
531- prepack_standard ( graph, k_cache_data, cache_storage, utils::kWidthPacked );
552+ graph. add_tensor_like ( k_cache_data, cache_storage, utils::kWidthPacked );
532553 const ValueRef v_cache =
533- prepack_standard ( graph, v_cache_data, cache_storage, utils::kWidthPacked );
554+ graph. add_tensor_like ( v_cache_data, cache_storage, utils::kWidthPacked );
534555
535556 update_cache_impl (graph, {k_projected, k_cache, input_pos_symint, -1 });
536557 update_cache_impl (graph, {v_projected, v_cache, input_pos_symint, -1 });
@@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl(
573594
574595 (void )sequence_len;
575596
576- utils::StorageType cache_storage = graph.storage_type_of (q_projected);
597+ const utils::StorageType cache_storage = graph.storage_type_of (q_projected);
577598 const ValueRef k_cache =
578599 graph.add_tensor_like (k_cache_data, cache_storage, utils::kWidthPacked );
579600 const ValueRef v_cache =
0 commit comments