58 changes: 45 additions & 13 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -700,10 +700,23 @@ void update_cache(
     const Tensor& cache,
     int64_t start_pos,
     int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
+  // 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
+  // 2) projected_value shape should be [bs, seq_len, num heads, head dim]
+  // 3) We're updating the cache with projected_value, at position start_pos
+
   ET_CHECK_MSG(
-      projected_value.size(0) == 1,
-      "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
+      projected_value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      projected_value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
+  ET_CHECK_MSG(
+      projected_value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      projected_value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");

   ET_CHECK_MSG(
       is_contiguous_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
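
The rewritten checks drop the old batch-size-1 restriction: the update tensor must now agree with the cache on batch, number of heads, head dim, and element width, and only the sequence dimension may differ. A minimal PyTorch sketch of these invariants, using hypothetical sizes chosen only for illustration:

    import torch

    # Hypothetical sizes; any values satisfying the checks work.
    bs, seq_len, max_seq_len, n_heads, head_dim = 5, 4, 256, 32, 128

    cache = torch.zeros(bs, max_seq_len, n_heads, head_dim)       # [bs, max_seq_len, heads, dim]
    projected_value = torch.rand(bs, seq_len, n_heads, head_dim)  # [bs, seq_len, heads, dim]

    # What the ET_CHECK_MSG calls above enforce:
    assert projected_value.size(0) == cache.size(0)                 # batch
    assert projected_value.size(2) == cache.size(2)                 # num heads
    assert projected_value.size(3) == cache.size(3)                 # head dim
    assert projected_value.element_size() == cache.element_size()   # dtype width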
@@ -714,16 +727,31 @@ void update_cache(
   ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
   ET_CHECK_MSG(cache_data, "cache data is null");

-  auto strides = cache.strides();
-  exec_aten::StridesType seq_dim_stride = strides[1];
-  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
-  exec_aten::SizesType pos_offset_bytes =
-      pos_offset * projected_value.element_size();
-  exec_aten::SizesType num_bytes =
-      projected_value.numel() * projected_value.element_size();
-  // NOLINTNEXTLINE
-  std::memcpy(
-      (uint8_t*)cache_data + pos_offset_bytes, projected_value_data, num_bytes);
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = projected_value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (projected_value.numel() / projected_value.size(0)) *
+      projected_value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < projected_value.size(0);
+       ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)projected_value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
 }

 } // anonymous namespace
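
The new loop copies one batch line per iteration: num_bytes_to_copy is the size of a single [seq_len, num heads, head dim] slab, and each iteration lands that slab at its batch line's start_pos offset in the cache. A PyTorch sketch of the same semantics, assuming the contiguous layout the checks above guarantee:

    import torch

    def update_cache_reference(projected_value, cache, start_pos):
        # projected_value: [bs, seq_len, n_heads, head_dim]
        # cache:           [bs, max_seq_len, n_heads, head_dim]
        bs, seq_len = projected_value.size(0), projected_value.size(1)
        for batch_line in range(bs):
            # Mirrors: cache_data + batch_line * cache_batch_dim_stride
            #                     + start_pos * cache_seq_dim_stride
            cache[batch_line, start_pos : start_pos + seq_len] = (
                projected_value[batch_line]
            )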
@@ -859,6 +887,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_key_dim_order.data(),
       util::kKVDim,
       sliced_key_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = key_cache.strides()[0];
   void* key_cache_data = key_cache.mutable_data_ptr();
   TensorImpl k_impl = TensorImpl(
       key_cache.scalar_type(),
@@ -883,6 +913,8 @@ Tensor& sdpa_with_kv_cache_out(
       sliced_value_dim_order.data(),
       util::kKVDim,
       sliced_value_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = value_cache.strides()[0];
   void* value_cache_data = value_cache.mutable_data_ptr();
   TensorImpl value_impl = TensorImpl(
       value_cache.scalar_type(),
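
Why the batch stride must be preserved in these two hunks: strides recomputed from the sliced sizes would describe a densely packed [bs, seq_len, heads, dim] buffer, but in the real cache each batch line is still max_seq_len * n_heads * head_dim elements apart. A PyTorch sketch of the difference, with hypothetical sizes:

    import torch

    bs, max_seq_len, n_heads, head_dim = 2, 8, 2, 4  # hypothetical
    seq_len = 3
    cache = torch.arange(bs * max_seq_len * n_heads * head_dim, dtype=torch.float)
    cache = cache.reshape(bs, max_seq_len, n_heads, head_dim)

    # Wrong: contiguous strides for the sliced sizes put batch line 1 only
    # seq_len * n_heads * head_dim elements into the buffer.
    packed_strides = torch.empty(bs, seq_len, n_heads, head_dim).stride()
    wrong = cache.as_strided((bs, seq_len, n_heads, head_dim), packed_strides)

    # Right: keep the cache's own batch stride, so batch line b still starts
    # at b * cache.stride(0).
    sliced = cache.as_strided((bs, seq_len, n_heads, head_dim), cache.stride())

    assert torch.equal(sliced, cache[:, :seq_len])
    assert not torch.equal(wrong, cache[:, :seq_len])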
27 changes: 17 additions & 10 deletions extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -373,10 +373,10 @@ class SDPATestCommon(unittest.TestCase):

     def setup_caches(self):
         self.k_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.v_cache = torch.zeros(
-            (1, self.max_seq_len, self.n_heads_kv, self.head_dim)
+            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
         )
         self.mask = torch.full(
             (self.max_seq_len, self.max_seq_len),
@@ -386,6 +386,7 @@ def setup_caches(self):

     def setUp(self):
         torch.manual_seed(42)
+        self.n_batch = 5
         self.n_heads_kv = 32
         self.n_heads_q = 32
         self.head_dim = 128
@@ -410,27 +411,27 @@ def _test_sdpa_common(
         scale_tensors=False,
     ):
         # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
-        tensor_scale_max = 20
-        tensor_scale_min = -20
+        tensor_scale_max = 15
+        tensor_scale_min = -15
         self.n_heads_kv = n_heads_kv
         self.n_heads_q = n_heads_q
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.setup_caches()
         q = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
@@ -448,19 +449,25 @@ def _test_sdpa_common(
         self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))

         q = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         k = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
         )
         v = self._scale_tensor(
-            torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
+            torch.rand(
+                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
+            ),
             tensor_scale_max,
             tensor_scale_min,
             scale_tensors,
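
For context, the test compares the op's output against a reference attention over the updated caches; its actual helper is elided above, so this is only a sketch of the shape handling involved, ignoring the causal mask for brevity. Note that F.scaled_dot_product_attention expects [bs, heads, seq, dim], while the caches in this file are laid out [bs, seq, heads, dim]:

    import torch
    import torch.nn.functional as F

    def reference_attention(q, k_cache, v_cache, start_pos, seq_len):
        # q: [bs, seq_len, n_heads, head_dim]
        # caches: [bs, max_seq_len, n_heads, head_dim], filled up to start_pos + seq_len
        k = k_cache[:, : start_pos + seq_len]
        v = v_cache[:, : start_pos + seq_len]
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        return out.transpose(1, 2)  # back to [bs, seq_len, n_heads, head_dim]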