
Commit 1e0568b

meta-emilian authored and facebook-github-bot committed
Batch-aware torch.ops.llama.sdpa_with_kv_cache (#4822)
Summary:
Pull Request resolved: #4822

This change makes torch.ops.llama.sdpa_with_kv_cache batch aware. This is needed for batched sdpa cases, for example LLM beam search.

* Makes update_cache update across the batch dimension. As a performance optimization, update_cache implements the following operation

```
k_cache[:, start_pos : start_pos + seq_len, :, :] = k
v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```

as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops. sdpa_with_kv_cache fuses this update with the flash attention op for tensors that follow a predetermined format [batch, length, heads, dim]. This change removes the assumption that batch == 1.

* Makes sdpa_with_kv_cache apply cpu_flash_attention to all batch lines as well.

ExecuTorch-exported Llama models are implemented with a greedy search, so it has not been necessary for this op to be batch-aware. However, when working with other models, or when doing LLM beam search, this is no longer true.

Reviewed By: kimishpatel

Differential Revision: D61605316
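For orientation before the diff, here is a minimal eager-PyTorch sketch of what the batch-aware cache update amounts to; `update_cache_reference` is an illustrative name only, not part of the op.

```python
import torch

def update_cache_reference(projected_value: torch.Tensor,
                           cache: torch.Tensor,
                           start_pos: int) -> None:
    # cache:           [bs, max_seq_len, num_heads, head_dim]
    # projected_value: [bs, seq_len,     num_heads, head_dim]
    bs, seq_len = projected_value.shape[:2]
    # The fused kernel does one contiguous copy per batch line; in eager
    # PyTorch this is just a slice assignment along the sequence dimension.
    for b in range(bs):
        cache[b, start_pos : start_pos + seq_len] = projected_value[b]

# Example: a cache for 2 batch lines, updated with 3 new positions at offset 5.
cache = torch.zeros(2, 16, 4, 8)
update_cache_reference(torch.ones(2, 3, 4, 8), cache, start_pos=5)
```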
1 parent 618466e commit 1e0568b

2 files changed (+60 lines, -18 lines)

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 45 additions & 13 deletions
```diff
@@ -700,10 +700,23 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
+  // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
+  // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
+  // 3) We're updating the cache with projected_value, at position start_pos
+
+  ET_CHECK_MSG(
+      projected_value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      projected_value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
   ET_CHECK_MSG(
-      projected_value.size(0) == 1,
-      "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+      projected_value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      projected_value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+
   ET_CHECK_MSG(
       is_contiguous_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
@@ -714,16 +727,31 @@ void update_cache(
   ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
   ET_CHECK_MSG(cache_data, "cache data is null");
 
-  auto strides = cache.strides();
-  exec_aten::StridesType seq_dim_stride = strides[1];
-  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
-  exec_aten::SizesType pos_offset_bytes =
-      pos_offset * projected_value.element_size();
-  exec_aten::SizesType num_bytes =
-      projected_value.numel() * projected_value.element_size();
-  // NOLINTNEXTLINE
-  std::memcpy(
-      (uint8_t*)cache_data + pos_offset_bytes, projected_value_data, num_bytes);
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = projected_value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (projected_value.numel() / projected_value.size(0)) *
+      projected_value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < projected_value.size(0);
+       ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)projected_value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
 }
 
 } // anonymous namespace
```
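Spelled out in Python terms, the byte offsets computed inside the loop above look roughly like this; strides come back in elements, so each term is scaled by the element size. `copy_offsets` is an illustrative helper, not code in this PR.

```python
def copy_offsets(cache, projected_value, start_pos, batch_line):
    # Bytes to copy: one batch line of projected_value, i.e. seq_len * heads * dim elements.
    num_bytes_to_copy = (
        projected_value.numel() // projected_value.size(0)
    ) * projected_value.element_size()
    # Destination: start of this batch line in the cache, advanced to start_pos.
    cache_pos_offset = (
        batch_line * cache.stride(0) + start_pos * cache.stride(1)
    ) * cache.element_size()
    # Source: start of this batch line in projected_value.
    value_pos_offset = batch_line * projected_value.stride(0) * projected_value.element_size()
    return cache_pos_offset, value_pos_offset, num_bytes_to_copy
```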
```diff
@@ -859,6 +887,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_key_dim_order.data(),
       util::kKVDim,
       sliced_key_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = key_cache.strides()[0];
   void* key_cache_data = key_cache.mutable_data_ptr();
   TensorImpl k_impl = TensorImpl(
       key_cache.scalar_type(),
@@ -883,6 +913,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_value_dim_order.data(),
       util::kKVDim,
       sliced_value_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = value_cache.strides()[0];
   void* value_cache_data = value_cache.mutable_data_ptr();
   TensorImpl value_impl = TensorImpl(
       value_cache.scalar_type(),
```
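The two one-line stride fixes above can be pictured in plain PyTorch: the op attends over a view covering only the first `start_pos + seq_len` cached positions, but that view lives inside a buffer sized for `max_seq_len`, so its batch stride must remain the full cache's batch stride rather than the contiguous stride implied by the sliced sizes. A minimal sketch with made-up sizes:

```python
import torch

n_batch, max_seq_len, n_heads, head_dim = 2, 16, 4, 8
start_pos, seq_len = 3, 2

cache = torch.zeros(n_batch, max_seq_len, n_heads, head_dim)

# View of the first start_pos + seq_len positions that keeps the full cache's
# strides, in particular the batch stride max_seq_len * n_heads * head_dim.
view = cache.as_strided(
    (n_batch, start_pos + seq_len, n_heads, head_dim),
    cache.stride(),
)
# The view shares storage with the cache and matches the ordinary slice.
assert torch.equal(view, cache[:, : start_pos + seq_len])
```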

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -373,10 +373,10 @@ class SDPATestCommon(unittest.TestCase):
 
     def setup_caches(self):
         self.k_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.v_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.mask = torch.full(
             (self.max_seq_len, self.max_seq_len),
@@ -386,6 +386,7 @@ def setup_caches(self):
 
     def setUp(self):
         torch.manual_seed(42)
+        self.n_batch = 2
         self.n_heads_kv = 32
         self.n_heads_q = 32
         self.head_dim = 128
@@ -418,19 +419,19 @@ def _test_sdpa_common(
         self.max_seq_len = max_seq_len
         self.setup_caches()
         q = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
@@ -466,6 +467,15 @@ def _test_sdpa_common(
             scale_tensors,
         )
 
+        q = torch.rand(
+            (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        k = torch.rand(
+            (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+        )
+        v = torch.rand(
+            (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+        )
         start_pos = seq_len
         seq_len = q.size(1)
         attn_mask = self.mask[start_pos : start_pos + seq_len, :]
```
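To see what the batched test is exercising, a compact eager-PyTorch reference can be written as below; masking, dropout, and scaling options are omitted, and `sdpa_with_kv_cache_reference` is a hypothetical helper, not the custom op itself.

```python
import torch
import torch.nn.functional as F

def sdpa_with_kv_cache_reference(q, k, v, k_cache, v_cache, start_pos):
    # All tensors use the op's [batch, length, heads, dim] layout.
    seq_len = q.size(1)
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    # F.scaled_dot_product_attention expects [batch, heads, length, dim].
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2),
        k_cache[:, : start_pos + seq_len].transpose(1, 2),
        v_cache[:, : start_pos + seq_len].transpose(1, 2),
    )
    return out.transpose(1, 2)

# With n_batch = 2 (e.g. two beams), the caches and the per-step q/k/v simply
# carry a leading batch dimension of 2; nothing at the call site changes.
n_batch, max_seq_len, n_heads, head_dim = 2, 8, 4, 16
k_cache = torch.zeros(n_batch, max_seq_len, n_heads, head_dim)
v_cache = torch.zeros(n_batch, max_seq_len, n_heads, head_dim)
q = torch.rand(n_batch, 1, n_heads, head_dim)
k = torch.rand(n_batch, 1, n_heads, head_dim)
v = torch.rand(n_batch, 1, n_heads, head_dim)
out = sdpa_with_kv_cache_reference(q, k, v, k_cache, v_cache, start_pos=0)
```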
