
Commit ae0964a

meta-emilian authored and facebook-github-bot committed
Making update_cache update across the batch dimension. (pytorch#4822)
Summary:
Pull Request resolved: pytorch#4822

This is part 1 of a multi-part commit to make torch.ops.llama.sdpa_with_kv_cache batch-aware. This is needed for batched SDPA cases, for example LLM beam search.

As a performance optimization, update_cache implements the following operation

```
k_cache[:, start_pos : start_pos + seq_len, :, :] = k
v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```

as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops. ExecuTorch-exported Llama models use greedy search, so this op has not needed to be batch-aware. However, when working with other models, or when doing LLM beam search, this code needs to update the cache across the batch dimension.

Differential Revision: D61605316
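For orientation, the batched semantics this change implements can be modeled on plain contiguous buffers. The sketch below is illustrative only (the function name, the std::vector storage, and the fixed float dtype are assumptions); the real op works on ExecuTorch tensors and uses their reported strides and element size:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Reference behavior for a contiguous [bs, max_seq_len, num_heads, head_dim]
// cache and a contiguous [bs, seq_len, num_heads, head_dim] projected value:
//   cache[:, start_pos : start_pos + seq_len, :, :] = value
void update_cache_reference(
    std::vector<float>& cache,
    const std::vector<float>& value,
    int64_t bs,
    int64_t max_seq_len,
    int64_t seq_len,
    int64_t num_heads,
    int64_t head_dim,
    int64_t start_pos) {
  const int64_t row = num_heads * head_dim; // elements per (batch, position)
  const int64_t cache_batch_stride = max_seq_len * row;
  const int64_t value_batch_stride = seq_len * row;
  for (int64_t b = 0; b < bs; ++b) {
    // One contiguous copy per batch line, analogous to what the new loop in
    // update_cache does with byte offsets.
    std::memcpy(
        cache.data() + b * cache_batch_stride + start_pos * row,
        value.data() + b * value_batch_stride,
        seq_len * row * sizeof(float));
  }
}
```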
1 parent c252553 commit ae0964a


extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 40 additions & 13 deletions
```diff
@@ -700,10 +700,23 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
+  // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
+  // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
+  // 3) We're updating the cache with projected_value, at position start_pos
+
+  ET_CHECK_MSG(
+      projected_value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      projected_value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
   ET_CHECK_MSG(
-      projected_value.size(0) == 1,
-      "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+      projected_value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      projected_value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+
   ET_CHECK_MSG(
       is_contiguous_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
@@ -714,16 +727,30 @@ void update_cache(
   ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
   ET_CHECK_MSG(cache_data, "cache data is null");
 
-  auto strides = cache.strides();
-  exec_aten::StridesType seq_dim_stride = strides[1];
-  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
-  exec_aten::SizesType pos_offset_bytes =
-      pos_offset * projected_value.element_size();
-  exec_aten::SizesType num_bytes =
-      projected_value.numel() * projected_value.element_size();
-  // NOLINTNEXTLINE
-  std::memcpy(
-      (uint8_t*)cache_data + pos_offset_bytes, projected_value_data, num_bytes);
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = projected_value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy = projected_value.numel() /
+      projected_value.size(0) * projected_value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < projected_value.size(0);
+       ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)projected_value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
 }
 
 } // anonymous namespace
```
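As a concrete illustration of the new offset arithmetic, with assumed sizes not taken from the commit: for a float32 cache of shape [2, 128, 8, 64] and a projected value of shape [2, 1, 8, 64], num_bytes_to_copy = (2 * 1 * 8 * 64 / 2) * 4 = 2048 bytes per batch line; cache_batch_dim_stride = 128 * 8 * 64 = 65536 elements and cache_seq_dim_stride = 8 * 64 = 512 elements, so batch line 1 at start_pos = 10 is written starting at byte offset (1 * 65536 + 10 * 512) * 4 = 282624.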
