92 changes: 92 additions & 0 deletions extension/llm/custom_ops/custom_ops.py
@@ -16,6 +16,12 @@

from torch.library import impl

aten = torch.ops.aten

from typing import Tuple

from torch._inductor.lowering import lowerings as L, register_lowering

try:
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None
@@ -387,3 +393,89 @@ def custom_quantized_sdpa_meta(
)

return torch.empty(query.size(), dtype=torch.float32, device="meta")


# 1) Define the custom op in the "executorch" namespace with name "alias"
@torch.library.custom_op("executorch::alias", mutates_args=())
def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Runtime implementation: just return the inputs as-is.
Works for both CPU and CUDA tensors because we don't do
any device-specific work here.
"""
# no copies, just pass-through
return x, y


# 2) FakeTensor kernel: describes output metadata for compile-time
@custom_alias.register_fake
def _(x, y):
# For this op, outputs have exactly the same shape/dtype/device as inputs.
# We just need *dummy* tensors with that metadata.
out_x = torch.empty_like(x)
out_y = torch.empty_like(y)
return out_x, out_y


@register_lowering(torch.ops.executorch.alias.default)
def lowering_custom_alias(x, y):
# x, y here are IR values (Inductor's internal representation).
# Alias is logically a no-op – just pass them through.
return x, y
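

# A minimal usage sketch: the helper name and tensor shapes below are illustrative
# assumptions, not part of the op registration itself. It shows the alias op being
# called through torch.ops and compiled, which exercises the fake kernel at trace
# time and the Inductor lowering above.
def _alias_usage_example() -> None:
    a = torch.randn(2, 3)
    b = torch.randn(2, 3)

    @torch.compile
    def fn(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.executorch.alias(x, y)

    out_a, out_b = fn(a, b)
    # Alias is a pass-through, so the outputs should equal the inputs.
    torch.testing.assert_close(out_a, a)
    torch.testing.assert_close(out_b, b)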


# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
torch._assert(value.dim() == 4, "value must be 4D")
torch._assert(cache.dim() == 4, "cache must be 4D")
# Cache shape: (B, H, S_max, D)
# Value shape: (B, H, S, D)
torch._assert(
value.size(2) <= cache.size(2),
f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
)
torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
torch._assert(value.dtype == cache.dtype, "dtype mismatch")


# This is cheating: we deliberately do NOT mark `cache` as mutating, so that this
# custom op can be used in higher-order ops such as `torch.cond`, where `torch.compile`
# requires no aliasing or mutation in the branches. This is fine because we only care
# about inference.
@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
# Eager implementation
_validate_cross_attn_cache_params(value, cache)

# Slice the cache to match value's sequence length and copy
# cache shape: [B, H, S_max, D]
# value shape: [B, H, S, D]
cache[:, :, : value.size(2), :].copy_(value)
return cache


# Register the fake (meta) kernel
@_update_cross_attn_cache.register_fake
def _update_cross_attn_cache_fake(
value: torch.Tensor, cache: torch.Tensor
) -> torch.Tensor:
_validate_cross_attn_cache_params(value, cache)
return torch.empty_like(cache)
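

# A minimal sketch (the helper name and shapes are illustrative assumptions): under
# FakeTensorMode, calling the op dispatches to the fake kernel registered above, so
# only output metadata is produced and no real cache memory is touched.
def _update_cross_attn_cache_fake_example() -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        value = torch.empty(2, 1, 2, 4)  # (B, H, S, D)
        cache = torch.empty(2, 1, 4, 4)  # (B, H, S_max, D)
        out = torch.ops.executorch.update_cross_attn_cache(value, cache)
        assert out.shape == cache.shape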


# Register Inductor lowering
@register_lowering(torch.ops.executorch.update_cross_attn_cache)
def _update_cross_attn_cache_lowering(value, cache):
# cache shape: [B, H, S_max, D]
# value shape: [B, H, S, D]

# We need to slice the cache along dim 2 (sequence length)
# slice(self, dim, start, end, step=1)
seq_len = value.get_size()[2]
cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)

# Copy value into the slice
L[aten.copy_.default](cache_slice, value)

return cache
178 changes: 178 additions & 0 deletions extension/llm/custom_ops/test_update_cross_attn_cache.py
@@ -0,0 +1,178 @@
import unittest

import torch

# Import the custom ops to ensure they are registered
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401


class TestUpdateCrossAttnCache(unittest.TestCase):
def test_update_cross_attn_cache(self):

# Create tensors
# Cache: [B=2, H=1, S_max=4, D=4]
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
# Value: [B=2, H=1, S=2, D=4] (S < S_max)
value = torch.randn(2, 1, 2, 4, dtype=torch.float32)

# Compile a function that uses the op
@torch.compile
def fn(v, c):
return torch.ops.executorch.update_cross_attn_cache(v, c)

# Run it
out = fn(value, cache)

# Check correctness
# The first 2 positions along the sequence dim (dim 2) should match value
torch.testing.assert_close(
cache[:, :, :2, :], value, msg="Cache slice not updated correctly"
)

# The output should match the updated cache (in eager the op returns the cache itself).
torch.testing.assert_close(
out, cache, msg="Output does not match cache"
)

# The rest should be zeros
torch.testing.assert_close(
cache[:, :, 2:, :],
torch.zeros_like(cache[:, :, 2:, :]),
msg="Rest of cache was modified",
)

def test_update_cross_attn_cache_in_cond(self):
# Create tensors

# Value: [B=2, H=1, S=2, D=4]
value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
# Alternative value for the false branch
value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)

# Define a function that uses the op inside torch.cond.
# torch.cond takes a single operands tuple that is passed to both branches,
# so each branch picks the value it needs.
def fn_with_cond(pred, v1, v2, c):
def true_fn(v1, v2, cache):
return torch.ops.executorch.update_cross_attn_cache(v1, cache)

def false_fn(v1, v2, cache):
return torch.ops.executorch.update_cross_attn_cache(v2, cache)

return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

# Test with true condition
pred_true = torch.tensor(True)
cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)

# Compile the function
@torch.compile
def compiled_fn(pred, v1, v2, c):
return fn_with_cond(pred, v1, v2, c)

# Run with true condition
compiled_fn(pred_true, value, value_alt, cache_true)

# Check that the true branch was executed (value was used)
torch.testing.assert_close(
cache_true[:, :, :2, :],
value,
msg="Cache not updated correctly in true branch",
)

# Test with false condition
pred_false = torch.tensor(False)
cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)

compiled_fn(pred_false, value, value_alt, cache_false)

# Check that the false branch was executed (value_alt was used)
torch.testing.assert_close(
cache_false[:, :, :2, :],
value_alt,
msg="Cache not updated correctly in false branch",
)

def test_update_cross_attn_cache_export(self):

# Create tensors
# Cache: [B=2, H=1, S_max=4, D=4]
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
# Value: [B=2, H=1, S=2, D=4]
value = torch.randn(2, 1, 2, 4, dtype=torch.float32)

# Define a module that uses the op
class UpdateCacheModule(torch.nn.Module):
def forward(self, v, c):
return torch.ops.executorch.update_cross_attn_cache(v, c)

module = UpdateCacheModule()

# Export the module
exported_program = torch.export.export(
module,
(value, cache),
)

# Run the exported program
cache_exported = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
exported_program.module()(value, cache_exported)

# Check correctness
torch.testing.assert_close(
cache_exported[:, :, :2, :],
value,
msg="Cache not updated correctly after export",
)

def test_update_cross_attn_cache_different_shapes(self):
print("Testing executorch::update_cross_attn_cache with different shapes...")

# Test with different batch sizes and sequence lengths
test_cases = [
# (B, S_max, S, H, D)
(1, 10, 5, 2, 8),
(4, 8, 3, 4, 16),
(2, 16, 10, 1, 32),
]

for B, S_max, S, H, D in test_cases:
cache = torch.zeros(B, H, S_max, D, dtype=torch.float32)
value = torch.randn(B, H, S, D, dtype=torch.float32)

@torch.compile
def fn(v, c):
return torch.ops.executorch.update_cross_attn_cache(v, c)

fn(value, cache)

# Check that the first S positions are updated
torch.testing.assert_close(
cache[:, :, :S, :],
value,
msg=f"Failed for shape B={B}, S_max={S_max}, S={S}, H={H}, D={D}",
)

# Check that the rest remain zeros
if S < S_max:
torch.testing.assert_close(
cache[:, :, S:, :],
torch.zeros_like(cache[:, :, S:, :]),
msg=f"Remaining cache modified for shape B={B}, S_max={S_max}, S={S}, H={H}, D={D}",
)

def test_update_cross_attn_cache_full_sequence(self):

# Cache: [B=2, H=1, S_max=4, D=4]
cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
# Value: [B=2, H=1, S=4, D=4] (S == S_max)
value = torch.randn(2, 1, 4, 4, dtype=torch.float32)

@torch.compile
def fn(v, c):
return torch.ops.executorch.update_cross_attn_cache(v, c)

fn(value, cache)

# The entire cache should match value
torch.testing.assert_close(
cache, value, msg="Cache not fully updated when S == S_max"
)