diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py
index dfa357fe356..15140fceeb8 100644
--- a/extension/llm/custom_ops/custom_ops.py
+++ b/extension/llm/custom_ops/custom_ops.py
@@ -16,6 +16,12 @@
 
 from torch.library import impl
 
+from typing import Tuple
+
+from torch._inductor.lowering import lowerings as L, register_lowering
+
+aten = torch.ops.aten
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,89 @@ def custom_quantized_sdpa_meta(
     )
 
     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Runtime implementation: just return the inputs as-is.
+    Works for both CPU and CUDA tensors because we don't do
+    any device-specific work here.
+    """
+    # no copies, just pass-through
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata for compile-time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op – just pass them through.
+    return x, y
+
+
+# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# This is cheating: we deliberately do NOT mark `cache` as mutated so that this
+# custom op can be used in HOPs such as `torch.cond`, where `torch.compile` requires
+# no aliasing or mutation in the branches. This is fine because we only care about inference.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    # Eager implementation
+    _validate_cross_attn_cache_params(value, cache)
+
+    # Slice the cache to match value's sequence length and copy
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+    cache[:, :, : value.size(2), :].copy_(value)
+    return cache
+
+
+# Register the fake (meta) kernel
+@_update_cross_attn_cache.register_fake
+def _update_cross_attn_cache_fake(
+    value: torch.Tensor, cache: torch.Tensor
+) -> torch.Tensor:
+    _validate_cross_attn_cache_params(value, cache)
+    return torch.empty_like(cache)
+
+
+# Register Inductor lowering
+@register_lowering(torch.ops.executorch.update_cross_attn_cache)
+def _update_cross_attn_cache_lowering(value, cache):
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+
+    # We need to slice the cache along dim 2 (sequence length)
+    # slice(self, dim, start, end, step=1)
+    seq_len = value.get_size()[2]
+    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
+
+    # Copy value into the slice
+    L[aten.copy_.default](cache_slice, value)
+
+    return cache
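Note: the eager kernel above writes into `cache` in place and returns it, even though `mutates_args` is left empty (see the comment preceding the op definition). For reference, the intended semantics under the documented (B, H, S_max, D) layout can be sketched in plain PyTorch as below; `update_cross_attn_cache_reference` is a hypothetical helper used only for illustration and is not part of this change:

    import torch


    def update_cross_attn_cache_reference(
        value: torch.Tensor, cache: torch.Tensor
    ) -> torch.Tensor:
        # Same result as the custom op's eager kernel, except the reference clones
        # the cache so the comparison stays free of side effects.
        out = cache.clone()
        out[:, :, : value.size(2), :] = value  # fill the first S slots of dim 2
        return out


    # Example with B=1, H=2, S_max=8, S=3, D=4
    cache = torch.zeros(1, 2, 8, 4)
    value = torch.randn(1, 2, 3, 4)
    out = update_cross_attn_cache_reference(value, cache)
    assert torch.equal(out[:, :, :3, :], value)
    assert torch.equal(out[:, :, 3:, :], torch.zeros(1, 2, 5, 4))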
diff --git a/extension/llm/custom_ops/test_update_cross_attn_cache.py b/extension/llm/custom_ops/test_update_cross_attn_cache.py
new file mode 100644
index 00000000000..e91e358384a
--- /dev/null
+++ b/extension/llm/custom_ops/test_update_cross_attn_cache.py
@@ -0,0 +1,178 @@
+import unittest
+
+import torch
+
+# Import the custom ops to ensure they are registered
+from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+
+class TestUpdateCrossAttnCache(unittest.TestCase):
+    def test_update_cross_attn_cache(self):
+
+        # Create tensors
+        # Cache: [B=2, H=1, S_max=4, D=4]
+        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        # Value: [B=2, H=1, S=2, D=4] (S < S_max)
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+
+        # Compile a function that uses the op
+        @torch.compile
+        def fn(v, c):
+            return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+        # Run it
+        out = fn(value, cache)
+
+        # Check correctness
+        # The first 2 positions along dim 2 (sequence) should match value
+        torch.testing.assert_close(
+            cache[:, :, :2, :], value, msg="Cache slice not updated correctly"
+        )
+
+        # Make sure out and cache match. In eager they are the same object.
+        torch.testing.assert_close(
+            out, cache, msg="Output does not match cache"
+        )
+
+        # The rest should be zeros
+        torch.testing.assert_close(
+            cache[:, :, 2:, :],
+            torch.zeros_like(cache[:, :, 2:, :]),
+            msg="Rest of cache was modified",
+        )
+
+    def test_update_cross_attn_cache_in_cond(self):
+        # Create tensors
+
+        # Value: [B=2, H=1, S=2, D=4]
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+        # Alternative value for false branch
+        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+
+        # Define a function that uses the op inside torch.cond.
+        # torch.cond passes a single operands tuple to both branches, so each
+        # branch receives (v1, v2, c) and picks the value it needs.
+        def fn_with_cond(pred, v1, v2, c):
+            def true_fn(v1, v2, cache):
+                return torch.ops.executorch.update_cross_attn_cache(v1, cache)
+
+            def false_fn(v1, v2, cache):
+                return torch.ops.executorch.update_cross_attn_cache(v2, cache)
+
+            return torch.cond(pred, true_fn, false_fn, (v1, v2, c))
+
+        # Test with true condition
+        pred_true = torch.tensor(True)
+        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+
+        # Compile the function
+        @torch.compile
+        def compiled_fn(pred, v1, v2, c):
+            return fn_with_cond(pred, v1, v2, c)
+
+        # Run with true condition
+        compiled_fn(pred_true, value, value_alt, cache_true)
+
+        # Check that the true branch was executed (value was used)
+        torch.testing.assert_close(
+            cache_true[:, :, :2, :],
+            value,
+            msg="Cache not updated correctly in true branch",
+        )
+
+        # Test with false condition
+        pred_false = torch.tensor(False)
+        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+
+        compiled_fn(pred_false, value, value_alt, cache_false)
+
+        # Check that the false branch was executed (value_alt was used)
+        torch.testing.assert_close(
+            cache_false[:, :, :2, :],
+            value_alt,
+            msg="Cache not updated correctly in false branch",
+        )
+
+    def test_update_cross_attn_cache_export(self):
+
+        # Create tensors
+        # Cache: [B=2, H=1, S_max=4, D=4]
+        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        # Value: [B=2, H=1, S=2, D=4]
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+
+        # Define a module that uses the op
+        class UpdateCacheModule(torch.nn.Module):
+            def forward(self, v, c):
+                return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+        module = UpdateCacheModule()
+
+        # Export the module
+        exported_program = torch.export.export(
+            module,
+            (value, cache),
+        )
+
+        # Run the exported program
+        cache_exported = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        exported_program.module()(value, cache_exported)
+
+        # Check correctness
+        torch.testing.assert_close(
+            cache_exported[:, :, :2, :],
+            value,
+            msg="Cache not updated correctly after export",
+        )
+
+    def test_update_cross_attn_cache_different_shapes(self):
+        # Test with different batch sizes and sequence lengths
+        test_cases = [
+            # (B, S_max, S, H, D)
+            (1, 10, 5, 2, 8),
+            (4, 8, 3, 4, 16),
+            (2, 16, 10, 1, 32),
+        ]
+
+        for B, S_max, S, H, D in test_cases:
+            # Layout is (B, H, S, D), matching the op's contract
+            cache = torch.zeros(B, H, S_max, D, dtype=torch.float32)
+            value = torch.randn(B, H, S, D, dtype=torch.float32)
+
+            @torch.compile
+            def fn(v, c):
+                return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+            fn(value, cache)
+
+            # Check that the first S positions are updated
+            torch.testing.assert_close(
+                cache[:, :, :S, :],
+                value,
+                msg=f"Failed for shape B={B}, S_max={S_max}, S={S}, H={H}, D={D}",
+            )
+
+            # Check that the rest remain zeros
+            if S < S_max:
+                torch.testing.assert_close(
+                    cache[:, :, S:, :],
+                    torch.zeros_like(cache[:, :, S:, :]),
+                    msg=f"Remaining cache modified for shape "
+                    f"B={B}, S_max={S_max}, S={S}, H={H}, D={D}",
+                )
+
+    def test_update_cross_attn_cache_full_sequence(self):
+
+        # Cache: [B=2, H=1, S_max=4, D=4]
+        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        # Value: [B=2, H=1, S=4, D=4] (S == S_max)
+        value = torch.randn(2, 1, 4, 4, dtype=torch.float32)
+
+        @torch.compile
+        def fn(v, c):
+            return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+        fn(value, cache)
+
+        # The entire cache should match value
+        torch.testing.assert_close(
+            cache, value, msg="Cache not fully updated when S == S_max"
+        )
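Note: beyond the unit tests above, a quick way to confirm that the op is preserved through torch.export as a single graph node is to inspect the exported graph directly. A minimal sketch, assuming the custom ops module has been imported so the op is registered (the module name and tensor shapes below are illustrative only):

    import torch

    from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


    class CrossAttnCacheModule(torch.nn.Module):
        def forward(self, v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)


    value = torch.randn(2, 1, 2, 4)  # (B, H, S, D)
    cache = torch.zeros(2, 1, 4, 4)  # (B, H, S_max, D)
    ep = torch.export.export(CrossAttnCacheModule(), (value, cache))

    # The custom op should show up as a single call_function node.
    assert any(
        node.op == "call_function"
        and node.target is torch.ops.executorch.update_cross_attn_cache.default
        for node in ep.graph.nodes
    )
    print(ep.graph_module.code)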