From fd426c9c95bc7be679b05fdc1bc62113a59a4fbe Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 1 May 2025 10:19:13 -0700 Subject: [PATCH] [Executorch][llm] Make custom update cache op operate on indices This allows us to use ring buffer kv cache Differential Revision: [D73891424](https://our.internmc.facebook.com/intern/diff/D73891424/) [ghstack-poisoned] --- .../source_transformation/custom_kv_cache.py | 63 +++-- extension/llm/custom_ops/custom_ops.py | 38 ++- extension/llm/custom_ops/op_sdpa_aot.cpp | 20 +- extension/llm/custom_ops/op_update_cache.cpp | 126 +++++++-- extension/llm/custom_ops/op_update_cache.h | 1 + extension/llm/custom_ops/test_update_cache.py | 245 ++++++++++++++++++ 6 files changed, 432 insertions(+), 61 deletions(-) diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 1158a8ba7a6..9361204f6bc 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -6,7 +6,7 @@ import logging from enum import Enum -from typing import Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -93,7 +93,7 @@ def _quantize(self, value): ) return quantized_value, scales, zero_points - def _quantize_and_update(self, input_pos, k_val, v_val): + def _quantize_and_update(self, input_pos, k_val, v_val, indices=None): quantized_k_val, k_scales, k_zero_points = self._quantize(k_val) quantized_v_val, v_scales, v_zero_points = self._quantize(v_val) @@ -104,17 +104,28 @@ def _quantize_and_update(self, input_pos, k_val, v_val): if self.use_custom_update_cache_op: start_pos = input_pos[0].item() - _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) - _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos) _ = torch.ops.llama.update_cache( - k_zero_points, self.k_cache_zero_points, start_pos + quantized_k_val, self.k_cache, start_pos, indices ) - _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos) - _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos) _ = torch.ops.llama.update_cache( - v_zero_points, self.v_cache_zero_points, start_pos + k_scales, self.k_cache_scales, start_pos, indices + ) + _ = torch.ops.llama.update_cache( + k_zero_points, self.k_cache_zero_points, start_pos, indices + ) + _ = torch.ops.llama.update_cache( + quantized_v_val, self.v_cache, start_pos, indices + ) + _ = torch.ops.llama.update_cache( + v_scales, self.v_cache_scales, start_pos, indices + ) + _ = torch.ops.llama.update_cache( + v_zero_points, self.v_cache_zero_points, start_pos, indices ) else: + assert indices is None, "Indices not supported for this path" + # Following is also broken because in prefill input_pos = [0] + # but we need to update some slice of cache self.k_cache[:, input_pos] = quantized_k_val self.k_cache_scales[:, input_pos] = k_scales self.k_cache_zero_points[:, input_pos] = k_zero_points @@ -122,8 +133,8 @@ def _quantize_and_update(self, input_pos, k_val, v_val): self.v_cache_scales[:, input_pos] = v_scales self.v_cache_zero_points[:, input_pos] = v_zero_points - def _update_and_return_float_values(self, input_pos, k_val, v_val): - self._quantize_and_update(input_pos, k_val, v_val) + def _update_and_return_float_values(self, input_pos, k_val, v_val, indices=None): + self._quantize_and_update(input_pos, k_val, v_val, indices) k_out = torch.ops.quantized_decomposed.dequantize_per_token( self.k_cache, @@ 
-144,24 +155,26 @@ def _update_and_return_float_values(self, input_pos, k_val, v_val): self.cache_fp_type, ) - # When returning float values we jsut use the last value + # When returning float values we just use the last value # instead of dequantized value. start_pos = input_pos[0].item() if self.use_custom_update_cache_op: - _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) - _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) + _ = torch.ops.llama.update_cache(k_val, k_out, start_pos, indices) + _ = torch.ops.llama.update_cache(v_val, v_out, start_pos, indices) else: k_out[:, input_pos] = k_val v_out[:, input_pos] = v_val return k_out, v_out - def _update_and_return_quantized_values(self, input_pos, k_val, v_val): - self._quantize_and_update(input_pos, k_val, v_val) + def _update_and_return_quantized_values( + self, input_pos, k_val, v_val, indices=None + ): + self._quantize_and_update(input_pos, k_val, v_val, indices) return self.k_cache, self.v_cache - def update(self, input_pos, k_val, v_val): + def update(self, input_pos, k_val, v_val, indices=None): """ k_val, v_val: [B, H, S, D] return: [B, H, S, D] @@ -172,10 +185,12 @@ def update(self, input_pos, k_val, v_val): v_val = v_val.transpose(1, 2) if self.return_float_values: - k_out, v_out = self._update_and_return_float_values(input_pos, k_val, v_val) + k_out, v_out = self._update_and_return_float_values( + input_pos, k_val, v_val, indices + ) else: k_out, v_out = self._update_and_return_quantized_values( - input_pos, k_val, v_val + input_pos, k_val, v_val, indices ) return k_out.transpose(1, 2), v_out.transpose(1, 2) @@ -277,14 +292,20 @@ def __init__( ) def update( - self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [S], k_val: [B, H, S, D] k_val = k_val.transpose(1, 2) v_val = v_val.transpose(1, 2) start_pos = input_pos[0].item() - _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos) - _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos) + + _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos, indices) + _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos, indices) + return ( self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2), diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py index 6d96a926497..0abc0a9b399 100644 --- a/extension/llm/custom_ops/custom_ops.py +++ b/extension/llm/custom_ops/custom_ops.py @@ -184,6 +184,7 @@ def _validate_update_cache_params( value, cache, start_pos, + indices=None, ): seq_len = value.size(1) assert ( @@ -200,17 +201,30 @@ def _validate_update_cache_params( ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}" torch._check_is_size(start_pos) - # Setting to arbitrary limit of 256 for now since there is no way - # to plumb this information from model config - torch._check(start_pos < cache.size(1)) - assert start_pos < cache.size( - 1 - ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}" - - torch._check((start_pos + seq_len) < cache.size(1)) - assert (start_pos + seq_len) < cache.size( - 1 - ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}" + if indices is None: + torch._check(start_pos < cache.size(1)) + assert start_pos < cache.size( + 1 + ), f"Start 
position {start_pos} must be less than sequence length {cache.size(1)}" + + torch._check((start_pos + seq_len) < cache.size(1)) + assert (start_pos + seq_len) < cache.size( + 1 + ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}" + + if indices is not None: + assert ( + indices.dim() == 2 + ), f"Expected indices to be 2 dimensional but got {indices.dim()} dimensions." + assert ( + indices.dtype == torch.int64 + ), f"Expected indices to be int64 but got {indices.dtype}" + assert indices.size(0) == value.size( + 0 + ), f"Expected indices batch dimension to match value batch dimension but got {indices.size(0)} and {value.size(0)}" + assert indices.size(1) == value.size( + 1 + ), f"Expected indices sequence length dimension to match value sequence length dimension but got {indices.size(1)} and {value.size(1)}" @impl(custom_ops_lib, "update_cache", "Meta") @@ -218,11 +232,13 @@ def update_cache_meta( value, cache, start_pos, + indices=None, ): _validate_update_cache_params( value, cache, start_pos, + indices, ) # Update cache doesnt really return anything but I dont know a better diff --git a/extension/llm/custom_ops/op_sdpa_aot.cpp b/extension/llm/custom_ops/op_sdpa_aot.cpp index ff367c85c8a..8c0b7a33e03 100644 --- a/extension/llm/custom_ops/op_sdpa_aot.cpp +++ b/extension/llm/custom_ops/op_sdpa_aot.cpp @@ -122,12 +122,14 @@ Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, const int64_t start_pos, + const std::optional indices, Tensor& output); at::Tensor update_cache_aten( const at::Tensor& value, at::Tensor& cache, - const int64_t start_pos); + const int64_t start_pos, + const std::optional& indices); Tensor& sdpa_with_kv_cache_out_no_context( const Tensor& q_projected, @@ -324,19 +326,21 @@ Tensor& update_cache_out_no_context( const Tensor& value, Tensor& cache, const int64_t start_pos, + const std::optional indices, Tensor& output) { executorch::aten::RuntimeContext context{}; return torch::executor::native::update_cache_out( - context, value, cache, start_pos, output); + context, value, cache, start_pos, indices, output); } at::Tensor update_cache_aten( const at::Tensor& value, at::Tensor& cache, - const int64_t start_pos) { + const int64_t start_pos, + const std::optional& indices) { auto output = at::empty({1}); - WRAP_TO_ATEN(update_cache_out_no_context, 3) - (value, cache, start_pos, output); + WRAP_TO_ATEN(update_cache_out_no_context, 4) + (value, cache, start_pos, indices, output); return output; } @@ -363,10 +367,10 @@ TORCH_LIBRARY_FRAGMENT(llama, m) { "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)"); m.def( "update_cache(Tensor value, Tensor(a!) cache, " - "SymInt start_pos) -> Tensor"); + "SymInt start_pos, Tensor? indices=None) -> Tensor"); m.def( "update_cache.out(Tensor value, Tensor(a!) cache, " - "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)"); + "SymInt start_pos, Tensor? indices=None, *, Tensor(b!) out) -> Tensor(b!)"); m.def( "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, " "Tensor? 
attn_mask=None, float drpout_p=0.0, bool is_causal=False, " @@ -396,7 +400,7 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) { m.impl("update_cache", torch::executor::native::update_cache_aten); m.impl( "update_cache.out", - WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3)); + WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4)); m.impl( "custom_quantized_sdpa", torch::executor::native::custom_quantized_sdpa_aten); diff --git a/extension/llm/custom_ops/op_update_cache.cpp b/extension/llm/custom_ops/op_update_cache.cpp index 323b7a65ddb..9539745b59e 100644 --- a/extension/llm/custom_ops/op_update_cache.cpp +++ b/extension/llm/custom_ops/op_update_cache.cpp @@ -24,7 +24,8 @@ bool validate_cache_params( const Tensor& quantized_value, const Tensor& quantized_cache, int64_t start_pos, - int64_t seq_length) { + int64_t seq_length, + const optional& indices = nullopt) { ET_CHECK_OR_RETURN_FALSE( quantized_cache.dim() == 4, "quantized cache must be a 4D tensor"); @@ -32,11 +33,14 @@ bool validate_cache_params( quantized_value.dim() == 4, "quantized_value must be a 4D tensor"); ET_CHECK_OR_RETURN_FALSE( - start_pos < quantized_cache.size(1), - "start_pos must be less than cache size at dim 1"); + indices.has_value() || start_pos < quantized_cache.size(1), + "start_pos: %" PRId64 " must be less than cache size at dim 1: %zd", + start_pos, + quantized_cache.size(1)); ET_CHECK_OR_RETURN_FALSE( - (start_pos + seq_length) <= quantized_cache.size(1), + indices.has_value() || + (start_pos + seq_length) <= quantized_cache.size(1), "start_post + seq_length must be less than max seq length supported by cache." "start pos: %" PRId64 ", seq_length: %" PRId64 "." @@ -45,6 +49,31 @@ bool validate_cache_params( seq_length, quantized_cache.size(1)); + // Validate indices tensor if provided + if (indices.has_value()) { + const Tensor& indices_tensor = indices.value(); + ET_CHECK_OR_RETURN_FALSE( + indices_tensor.dim() == 2, + "indices must be a 2D tensor [batch_size, seq_len]"); + + ET_CHECK_OR_RETURN_FALSE( + indices_tensor.size(0) == quantized_value.size(0), + "indices batch dimension must match value batch dimension"); + + ET_CHECK_OR_RETURN_FALSE( + indices_tensor.size(1) == quantized_value.size(1), + "indices sequence length dimension must match value sequence length dimension"); + + ET_CHECK_OR_RETURN_FALSE( + indices_tensor.scalar_type() == ScalarType::Long, + "indices must be of Long (int64_t) type"); + + ET_CHECK_OR_RETURN_FALSE( + is_contiguous_dim_order( + indices_tensor.dim_order().data(), indices_tensor.dim()), + "indices must be in contiguous dim order"); + } + // Make sure they are in contiguous dim order ET_CHECK_OR_RETURN_FALSE( is_contiguous_dim_order( @@ -65,27 +94,36 @@ Tensor& update_cache_out( const Tensor& value, Tensor& cache, const int64_t start_pos, + const optional& indices, Tensor& output) { (void)ctx; int64_t seq_len = value.size(1); ET_KERNEL_CHECK( ctx, - validate_cache_params(value, cache, start_pos, seq_len), + validate_cache_params(value, cache, start_pos, seq_len, indices), InvalidArgument, output); ET_CHECK_MSG( value.size(0) == cache.size(0), - "projected_value batch size should be equal to the cache batch size."); + "projected_value batch size (%zd) should be equal to the cache batch size (%zd).", + value.size(0), + cache.size(0)); ET_CHECK_MSG( value.size(2) == cache.size(2), - "projected_value number of heads should be equal to the cache number of heads."); + "projected_value number of heads (%zd) should be equal to the cache 
number of heads (%zd).", + value.size(2), + cache.size(2)); ET_CHECK_MSG( value.size(3) == cache.size(3), - "projected_value embedding dimension should be equal to the cache embedding dimension."); + "projected_value embedding dimension (%zd) should be equal to the cache embedding dimension (%zd).", + value.size(3), + cache.size(3)); ET_CHECK_MSG( value.element_size() == cache.element_size(), - "projected_value data type size should be equal to the cache data type size."); + "projected_value data type size (%zd) should be equal to the cache data type size (%zd).", + value.element_size(), + cache.element_size()); ET_CHECK_MSG( is_contiguous_dim_order(value.dim_order().data(), value.dim()), @@ -110,18 +148,64 @@ Tensor& update_cache_out( executorch::aten::SizesType num_bytes_to_copy = (value.numel() / value.size(0)) * value.element_size(); - for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { - executorch::aten::SizesType cache_pos_offset = - (batch_line * cache_batch_dim_stride + - start_pos * cache_seq_dim_stride) * - cache.element_size(); - executorch::aten::SizesType value_pos_offset = - (batch_line * value_batch_dim_stride) * cache.element_size(); - - std::memcpy( - (uint8_t*)cache_data + cache_pos_offset, - (uint8_t*)value_data + value_pos_offset, - num_bytes_to_copy); + if (indices.has_value()) { + // Use the provided indices tensor for each batch and sequence position + const Tensor& indices_tensor = indices.value(); + const int64_t* indices_data = indices_tensor.const_data_ptr(); + auto indices_strides = indices_tensor.strides(); + executorch::aten::StridesType indices_batch_stride = indices_strides[0]; + executorch::aten::StridesType indices_seq_stride = indices_strides[1]; + + // Calculate bytes to copy for a single token + executorch::aten::SizesType bytes_per_token = + (value.numel() / (value.size(0) * value.size(1))) * + value.element_size(); + + for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { + for (int64_t seq_idx = 0; seq_idx < value.size(1); ++seq_idx) { + // Get the target position from the indices tensor + int64_t target_pos = indices_data + [batch_line * indices_batch_stride + seq_idx * indices_seq_stride]; + + // Ensure the target position is valid + ET_CHECK_MSG( + target_pos >= 0 && target_pos < cache.size(1), + "Index out of bounds: %" PRId64 " not in [0, %zd)", + target_pos, + cache.size(1)); + + // Calculate offsets for cache and value + executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride + + target_pos * cache_seq_dim_stride) * + cache.element_size(); + + executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride + seq_idx * value_strides[1]) * + value.element_size(); + + // Copy a single token + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)value_data + value_pos_offset, + bytes_per_token); + } + } + } else { + // Use the original implementation with start_pos + for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) { + executorch::aten::SizesType cache_pos_offset = + (batch_line * cache_batch_dim_stride + + start_pos * cache_seq_dim_stride) * + cache.element_size(); + executorch::aten::SizesType value_pos_offset = + (batch_line * value_batch_dim_stride) * cache.element_size(); + + std::memcpy( + (uint8_t*)cache_data + cache_pos_offset, + (uint8_t*)value_data + value_pos_offset, + num_bytes_to_copy); + } } // Noone uses output. Just a placeholder. 
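For reference, the indices branch above amounts to a per-batch advanced-indexing assignment: each token value[b, s] is copied into cache[b, indices[b, s]]. A minimal Python sketch of those semantics, assuming value is [B, S, H, D], cache is [B, max_seq_len, H, D], and indices is an int64 tensor of shape [B, S]; the function name is illustrative and not part of the op:

    import torch

    def update_cache_with_indices_reference(value, cache, indices):
        # Same bounds the kernel enforces before copying.
        assert indices.dtype == torch.int64 and indices.shape == value.shape[:2]
        assert 0 <= int(indices.min()) and int(indices.max()) < cache.size(1)
        for b in range(value.size(0)):
            # Scatter this batch row's tokens into their target cache slots.
            # Note: the kernel copies tokens sequentially, so duplicate indices keep
            # the last write; plain index assignment does not guarantee an order.
            cache[b, indices[b]] = value[b]
        return cache
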
diff --git a/extension/llm/custom_ops/op_update_cache.h b/extension/llm/custom_ops/op_update_cache.h index cf518b4e108..8cfc0869b7d 100644 --- a/extension/llm/custom_ops/op_update_cache.h +++ b/extension/llm/custom_ops/op_update_cache.h @@ -20,6 +20,7 @@ Tensor& update_cache_out( const Tensor& value, Tensor& cache, const int64_t start_pos, + const optional& indices, Tensor& output); } // namespace native } // namespace executor diff --git a/extension/llm/custom_ops/test_update_cache.py b/extension/llm/custom_ops/test_update_cache.py index 1d2f392c129..f5bb4b7c732 100644 --- a/extension/llm/custom_ops/test_update_cache.py +++ b/extension/llm/custom_ops/test_update_cache.py @@ -6,11 +6,28 @@ # pyre-unsafe +import multiprocessing import unittest import torch +def run_in_subprocess(target): + """ + Decorator to run the target function in a separate subprocess + so as to allow cpp code to throw runtime::abort + """ + + def wrapper(*args, **kwargs): + p = multiprocessing.Process(target=target, args=args, kwargs=kwargs) + p.start() + p.join() + if p.exitcode != 0: + raise Exception(f"Subprocess failed with exit code {p.exitcode}") + + return wrapper + + class UpdateQuantizedKVCacheTest(unittest.TestCase): def _reset(self): @@ -82,6 +99,36 @@ def _update_and_validate( self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache)) self.assertTrue(torch.allclose(v_zero_points_cache, self.v_zero_points_cache)) + def _update_with_indices_and_validate( + self, k, k_scales, k_zero_points, start_pos, indices + ): + k_cache = self.quantized_k_cache.clone() + k_scales_cache = self.k_scales_cache.clone() + k_zero_points_cache = self.k_zero_points_cache.clone() + + # Update using Python indexing for reference + for batch_idx in range(self.batch_size): + for seq_idx in range(indices.size(1)): + idx = indices[batch_idx, seq_idx].item() + if idx >= 0 and idx < self.seq_len: + self.quantized_k_cache[batch_idx, idx] = k[batch_idx, seq_idx] + self.k_scales_cache[batch_idx, idx] = k_scales[batch_idx, seq_idx] + self.k_zero_points_cache[batch_idx, idx] = k_zero_points[ + batch_idx, seq_idx + ] + + # Update using custom op + torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos, indices) + torch.ops.llama.update_cache( + k_zero_points, k_zero_points_cache, start_pos, indices + ) + + # Validate results + self.assertTrue(torch.allclose(k_cache, self.quantized_k_cache)) + self.assertTrue(torch.allclose(k_scales_cache, self.k_scales_cache)) + self.assertTrue(torch.allclose(k_zero_points_cache, self.k_zero_points_cache)) + def test_update_kv_cache_simple(self): k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) v = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) @@ -94,6 +141,204 @@ def test_update_kv_cache_simple(self): k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos ) + # Tests for update_cache_with_indices functionality + + def test_basic_update_with_indices(self): + """Test basic update with indices functionality.""" + self._reset() + k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) + k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) + + # Update positions 2, 5, 7 + indices = torch.tensor([[2, 5, 7]], dtype=torch.int64) + start_pos = 0 # start_pos is ignored when indices are provided + + self._update_with_indices_and_validate( + k, k_scales, k_zero_points, start_pos, indices + ) + + def 
test_single_index_update(self): + """Test updating a single position with indices.""" + self._reset() + k = torch.randint(0, 50, (1, 1, 8, 4), dtype=torch.int8) + k_scales = torch.rand((1, 1, 8, 1), dtype=torch.float64) + k_zero_points = torch.randint(0, 20, (1, 1, 8, 1), dtype=torch.int64) + + # Update only position 4 + indices = torch.tensor([[4]], dtype=torch.int64) + start_pos = 0 + + self._update_with_indices_and_validate( + k, k_scales, k_zero_points, start_pos, indices + ) + + def test_sparse_indices(self): + """Test updating non-contiguous positions.""" + self._reset() + k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) + k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) + + # Update positions 1, 4, 8 (sparse, non-contiguous) + indices = torch.tensor([[1, 4, 8]], dtype=torch.int64) + start_pos = 0 + + self._update_with_indices_and_validate( + k, k_scales, k_zero_points, start_pos, indices + ) + + def test_out_of_order_indices(self): + """Test updating positions in a non-sequential order.""" + self._reset() + k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) + k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) + + # Update positions in reverse order: 8, 5, 2 + indices = torch.tensor([[8, 5, 2]], dtype=torch.int64) + start_pos = 0 + + self._update_with_indices_and_validate( + k, k_scales, k_zero_points, start_pos, indices + ) + + def test_indices_exceeding_cache_size(self): + """Test behavior when indices exceed the cache size.""" + self._reset() + k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + + # Try to update positions 5, 9, 15 (where 15 is out of bounds) + indices = torch.tensor([[5, 9, 15]], dtype=torch.int64) + start_pos = 0 + + @run_in_subprocess + def run_and_catch(k, k_cache, start_pos, indices): + torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + + exception_raised = False + try: + run_and_catch(k, self.quantized_k_cache, start_pos, indices) + except Exception: + exception_raised = True + self.assertTrue(exception_raised) + + def test_negative_indices(self): + """Test behavior with negative indices.""" + self._reset() + k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + + # Try to update with negative indices + indices = torch.tensor([[5, -1, 8]], dtype=torch.int64) + start_pos = 0 + + @run_in_subprocess + def run_and_catch(k, k_cache, start_pos, indices): + torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + + exception_raised = False + try: + run_and_catch(k, self.quantized_k_cache, start_pos, indices) + except Exception: + exception_raised = True + self.assertTrue(exception_raised) + + def test_duplicate_indices(self): + """Test behavior when the same position is updated multiple times.""" + self._reset() + k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + v = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8) + k_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) + v_scales = torch.rand((1, 3, 8, 1), dtype=torch.float64) + k_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) + v_zero_points = torch.randint(0, 20, (1, 3, 8, 1), dtype=torch.int64) + + # Update with duplicate indices - the last value should be used + indices = torch.tensor([[3, 5, 3]], dtype=torch.int64) + start_pos = 0 + + # For our reference implementation, we need to handle this case specially + k_cache = self.quantized_k_cache.clone() + v_cache = 
self.quantized_v_cache.clone() + k_scales_cache = self.k_scales_cache.clone() + v_scales_cache = self.v_scales_cache.clone() + k_zero_points_cache = self.k_zero_points_cache.clone() + v_zero_points_cache = self.v_zero_points_cache.clone() + + # Update using custom op + torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + torch.ops.llama.update_cache(k_scales, k_scales_cache, start_pos, indices) + torch.ops.llama.update_cache( + k_zero_points, k_zero_points_cache, start_pos, indices + ) + torch.ops.llama.update_cache(v, v_cache, start_pos, indices) + torch.ops.llama.update_cache(v_scales, v_scales_cache, start_pos, indices) + torch.ops.llama.update_cache( + v_zero_points, v_zero_points_cache, start_pos, indices + ) + + # Position 3 should have the value from the last update (index 2 in the sequence) + self.assertTrue(torch.allclose(k_cache[0, 3], k[0, 2])) + self.assertTrue(torch.allclose(v_cache[0, 3], v[0, 2])) + self.assertTrue(torch.allclose(k_scales_cache[0, 3], k_scales[0, 2])) + self.assertTrue(torch.allclose(v_scales_cache[0, 3], v_scales[0, 2])) + self.assertTrue(torch.allclose(k_zero_points_cache[0, 3], k_zero_points[0, 2])) + self.assertTrue(torch.allclose(v_zero_points_cache[0, 3], v_zero_points[0, 2])) + + # Position 5 should have the value from index 1 + self.assertTrue(torch.allclose(k_cache[0, 5], k[0, 1])) + self.assertTrue(torch.allclose(v_cache[0, 5], v[0, 1])) + + def test_batched_update_with_indices(self): + """Test updating with indices in a batched setting.""" + self.batch_size = 2 + self._reset() + k = torch.randint(0, 50, (self.batch_size, 3, 8, 4), dtype=torch.int8) + k_scales = torch.rand((self.batch_size, 3, 8, 1), dtype=torch.float64) + k_zero_points = torch.randint( + 0, 20, (self.batch_size, 3, 8, 1), dtype=torch.int64 + ) + + # Different indices for each batch + indices = torch.tensor( + [[1, 4, 7], [2, 5, 8]], # indices for batch 0 # indices for batch 1 + dtype=torch.int64, + ) + start_pos = 0 + + self._update_with_indices_and_validate( + k, k_scales, k_zero_points, start_pos, indices + ) + + def test_different_seq_lengths_per_batch(self): + """Test updating with different sequence lengths per batch using padding.""" + self.batch_size = 2 + self._reset() + + # Create inputs with 3 tokens + k = torch.randint(0, 50, (self.batch_size, 3, 8, 4), dtype=torch.int8) + + # Batch 0: update 3 positions, Batch 1: update only 2 positions (use -1 as padding) + indices = torch.tensor( + [ + [1, 3, 5], # 3 valid indices for batch 0 + [2, 4, -1], # 2 valid indices for batch 1, with -1 as padding + ], + dtype=torch.int64, + ) + start_pos = 0 + + @run_in_subprocess + def run_and_catch(k, k_cache, start_pos, indices): + torch.ops.llama.update_cache(k, k_cache, start_pos, indices) + + exception_raised = False + try: + run_and_catch(k, self.quantized_k_cache, start_pos, indices) + except Exception: + exception_raised = True + self.assertTrue(exception_raised) + def test_update_kv_cache_large_update(self): self._reset() k = torch.randint(0, 50, (1, 3, 8, 4), dtype=torch.int8)
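
Usage note: the optional indices argument is what enables the ring buffer KV cache mentioned in the commit message, since writes are no longer restricted to the contiguous slots starting at start_pos. A minimal sketch of that usage, with an illustrative modulo wrap-around policy and helper name that are not taken from this patch, and assuming the custom ops are registered by importing executorch.extension.llm.custom_ops:

    import torch
    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401  (registers torch.ops.llama.*)

    def ring_buffer_indices(start_pos, seq_len, cache_len, batch_size=1):
        # Map absolute positions [start_pos, start_pos + seq_len) onto cache slots modulo cache_len.
        pos = torch.arange(start_pos, start_pos + seq_len, dtype=torch.int64)
        return (pos % cache_len).unsqueeze(0).expand(batch_size, -1).contiguous()

    # Example: a cache with 8 slots receiving tokens 6..9 writes to slots [6, 7, 0, 1].
    k_cache = torch.zeros(1, 8, 4, 16)  # [B, max_seq_len, H, D]
    k_val = torch.randn(1, 4, 4, 16)    # [B, S, H, D]
    indices = ring_buffer_indices(start_pos=6, seq_len=4, cache_len=k_cache.size(1))
    torch.ops.llama.update_cache(k_val, k_cache, 6, indices)  # start_pos is ignored when indices are given

The op only writes the cache; masking out positions that have been overwritten is up to the attention module that consumes it.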