
Commit 0e66ad3

[ET-VK][ez] Ensure that attn_weight buffers do not exceed GPU buffer numel limit
Differential Revision: D86443407
Pull Request resolved: #15651
1 parent 944cccf

2 files changed: +32 −7 lines changed


backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 4 additions & 0 deletions
@@ -639,6 +639,10 @@ class ComputeGraph final {
 
   bool device_name_contains(const char* substr);
 
+  int64_t max_buffer_numel() {
+    return static_cast<int64_t>(context_->adapter_ptr()->max_buffer_numel());
+  }
+
   //
   // Graph Building
   //
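The new accessor just forwards the Vulkan adapter's buffer element limit through the graph. As a rough caller-side illustration (not part of the commit; num_q_heads, max_seq_len, and max_context_len are hypothetical locals, and the texture fallback is an assumption), it could be used to decide whether a planned tensor fits in buffer storage:

  // Hypothetical guard, assuming a ComputeGraph& graph is in scope.
  const int64_t planned_numel = num_q_heads * max_seq_len * max_context_len;
  const utils::StorageType storage = planned_numel < graph.max_buffer_numel()
      ? utils::kBuffer      // fits within the GPU buffer numel limit
      : utils::kTexture3D;  // fall back to texture storage otherwise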

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 28 additions & 7 deletions
@@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
   const int64_t num_q_heads = graph.size_at<int64_t>(-2, q_projected);
-  const int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
-
+  int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
   const int64_t max_context_len = graph.size_at<int32_t>(-3, k_cache);
 
+  const utils::StorageType attn_weights_storage =
+      graph.storage_type_of(q_projected);
+
+  // If using buffer storage for attn weights, we need to ensure that the buffer
+  // numel limit is not exceeded. If needed, manually adjust max_seq_len based
+  // on the buffer numel limit.
+  if (attn_weights_storage == utils::kBuffer) {
+    const int64_t max_buffer_numel = graph.max_buffer_numel();
+    if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
+      // Compute the maximum possible value for max_seq_len that will hit
+      // the buffer numel limit.
+      max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
+      // Adjust down to the nearest multiple of 4 to make sure the limit is
+      // not hit.
+      if (max_seq_len % 4 != 0) {
+        max_seq_len = (max_seq_len / 4) * 4;
+      } else {
+        max_seq_len -= 4;
+      }
+    }
+  }
+
   std::vector<int64_t> attn_weight_full_sizes = {
       1, // batch
       num_q_heads,
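To make the clamping arithmetic concrete, here is a minimal standalone restatement of the logic above with a worked example (the helper name and the sample numbers are illustrative, not from the commit):

  #include <cstdint>
  #include <iostream>

  // Standalone sketch of the clamping logic from sdpa_impl above.
  int64_t clamp_max_seq_len(
      int64_t max_seq_len,
      int64_t num_q_heads,
      int64_t max_context_len,
      int64_t max_buffer_numel) {
    if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
      // Largest max_seq_len at (or just under) the limit.
      max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
      if (max_seq_len % 4 != 0) {
        max_seq_len = (max_seq_len / 4) * 4;  // round down to a multiple of 4
      } else {
        max_seq_len -= 4;  // already a multiple of 4: step below the limit
      }
    }
    return max_seq_len;
  }

  int main() {
    // 32 heads, 2048-token context, 2^27-element buffer limit:
    // 32 * 4096 * 2048 = 268,435,456 >= 134,217,728, so clamp;
    // 134,217,728 / (32 * 2048) = 2048, a multiple of 4, so return 2044.
    std::cout << clamp_max_seq_len(4096, 32, 2048, 134217728) << "\n";
  }

The multiple-of-4 adjustment presumably keeps the sequence dimension friendly to the width-packed (texel-of-4) layout these tensors use; when the quotient is already a multiple of 4, the code steps down a full 4, keeping the product strictly below the limit.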
@@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   TmpTensor attn_weights_softmax(
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   add_sdpa_compute_attn_weights_node(
@@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl(
 
   utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
   update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
   update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl(
 
   (void)sequence_len;
 
-  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
+  const utils::StorageType cache_storage = graph.storage_type_of(q_projected);
   const ValueRef k_cache =
       graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
   const ValueRef v_cache =
