From 9508947947ba5a5c841cb45b4047e1f25f7777fd Mon Sep 17 00:00:00 2001
From: Emilian Stoimenov
Date: Tue, 17 Sep 2024 14:50:52 -0700
Subject: [PATCH] Batch-aware torch.ops.llama.sdpa_with_kv_cache (#4822)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4822

This change makes torch.ops.llama.sdpa_with_kv_cache batch-aware. This is
needed for batched SDPA cases, for example LLM beam search.

* Makes update_cache update across the batch dimension

As a performance optimization, update_cache implements the following operation

```
k_cache[:, start_pos : start_pos + seq_len, :, :] = k
v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```

as part of the fused sdpa_with_kv_cache op. A naive export of this code
inserts expensive slice-scatter ops. sdpa_with_kv_cache fuses this update
with the flash attention op for tensors that follow a predetermined format
[batch, length, heads, dim]. This change removes the assumption that
batch == 1.

* Makes sdpa_with_kv_cache apply cpu_flash_attention to all batch lines as well

ExecuTorch-exported Llama models are implemented with greedy search, so this
op has not needed to be batch-aware. However, when working with other models
or when doing LLM beam search, this assumption no longer holds.

Reviewed By: kimishpatel

Differential Revision: D61605316
---
 extension/llm/custom_ops/op_sdpa.cpp          | 58 ++++++++++++++-----
 .../llm/custom_ops/test_sdpa_with_kv_cache.py | 27 +++++----
 2 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 0bb168bdadb..e8a53a41312 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -700,10 +700,23 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
+  // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
+  // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
+  // 3) We're updating the cache with projected_value, at position start_pos
+
+  ET_CHECK_MSG(
+      projected_value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      projected_value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
   ET_CHECK_MSG(
-      projected_value.size(0) == 1,
-      "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+      projected_value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      projected_value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+
   ET_CHECK_MSG(
       is_contiguous_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
@@ -714,16 +727,31 @@ void update_cache(
   ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
   ET_CHECK_MSG(cache_data, "cache data is null");
 
-  auto strides = cache.strides();
-  exec_aten::StridesType seq_dim_stride = strides[1];
-  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
-  exec_aten::SizesType pos_offset_bytes =
-      pos_offset * projected_value.element_size();
-  exec_aten::SizesType num_bytes =
-      projected_value.numel() * projected_value.element_size();
-  // NOLINTNEXTLINE
-  std::memcpy(
-      (uint8_t*)cache_data + pos_offset_bytes, projected_value_data, num_bytes);
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = projected_value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (projected_value.numel() / projected_value.size(0)) *
+      projected_value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < projected_value.size(0);
+       ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)projected_value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
 }
 
 } // anonymous namespace
@@ -859,6 +887,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_key_dim_order.data(),
       util::kKVDim,
       sliced_key_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = key_cache.strides()[0];
   void* key_cache_data = key_cache.mutable_data_ptr();
   TensorImpl k_impl = TensorImpl(
       key_cache.scalar_type(),
@@ -883,6 +913,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_value_dim_order.data(),
       util::kKVDim,
       sliced_value_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = value_cache.strides()[0];
   void* value_cache_data = value_cache.mutable_data_ptr();
   TensorImpl value_impl = TensorImpl(
       value_cache.scalar_type(),
diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
index dd63c68f138..bfd64cb8975 100644
--- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -373,10 +373,10 @@ class SDPATestCommon(unittest.TestCase):
 
     def setup_caches(self):
         self.k_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.v_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.mask = torch.full(
             (self.max_seq_len, self.max_seq_len),
@@ -386,6 +386,7 @@ def setup_caches(self):
 
     def setUp(self):
         torch.manual_seed(42)
+        self.n_batch = 5
         self.n_heads_kv = 32
         self.n_heads_q = 32
         self.head_dim = 128
@@ -410,27 +411,27 @@ def _test_sdpa_common(
         scale_tensors=False,
     ):
         # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
-        tensor_scale_max = 20
-        tensor_scale_min = -20
+        tensor_scale_max = 15
+        tensor_scale_min = -15
         self.n_heads_kv = n_heads_kv
         self.n_heads_q = n_heads_q
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.setup_caches()
         q = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
@@ -448,19 +449,25 @@ def _test_sdpa_common(
         self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
 
         q = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
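For reference, the batch-aware update performed by the per-batch-line memcpy loop above is equivalent to the following PyTorch sketch. The function name and assertions here are illustrative only and are not part of the patch:

```
import torch

def update_cache_reference(projected_value: torch.Tensor,
                           cache: torch.Tensor,
                           start_pos: int) -> None:
    # Both tensors are laid out as [batch, seq_len, num_heads, head_dim];
    # the cache's second dimension is max_seq_len.
    batch, seq_len, _, _ = projected_value.shape
    assert cache.size(0) == batch                     # batch sizes must match
    assert cache.size(2) == projected_value.size(2)   # number of heads must match
    assert cache.size(3) == projected_value.size(3)   # head dim must match
    # One contiguous copy per batch line, mirroring the fused kernel's loop,
    # which offsets each copy by batch stride plus start_pos * sequence stride.
    for b in range(batch):
        cache[b, start_pos : start_pos + seq_len] = projected_value[b]
```

Each batch line is copied separately because the sliced caches keep the full cache's batch stride, so consecutive batch lines of the slice are not contiguous in memory.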