
Commit 00e89d7

Custom op to update cache for torch.cond
`torch.cond` doesn't allow aliasing or mutation in its branches. This adds 2 ops to support conditionally updating the KV cache:

* `executorch::alias`: takes 2 tensors and returns the same 2 tensors.
* `executorch::update_cross_attn_cache`: takes a tensor `cache` and a tensor `value`, and copies `value` into `cache` in place.

With these 2 ops, we can rewrite the model definition from:

```py
if is_cross_attention and past_key_values and is_updated:
    # reuse k,v, cross_attentions
    key_states = past_key_values.layers[self.layer_idx].keys
    value_states = past_key_values.layers[self.layer_idx].values
else:
    key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()

    if past_key_values is not None:
        # save all key/value_states to cache to be re-used for fast auto-regressive generation
        cache_position = cache_position if not is_cross_attention else None
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, {"cache_position": cache_position}
        )
```

Into:

```py
def use_cached_kv(
    cached_keys: Tensor,
    cached_values: Tensor,
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Just reuse cached K/V
    return torch.ops.executorch.alias(cached_keys, cached_values)


def recompute_kv(
    cached_keys: Tensor,  # unused
    cached_values: Tensor,  # unused
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Compute fresh K/V (export-friendly: no cache mutation in here)
    key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
    v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
    return k, v


if past_key_values is not None and self.layer_idx is not None:
    # Grab cached tensors (these are Tensors, so they are OK for export)
    cached_keys = past_key_values.layers[self.layer_idx].keys
    cached_values = past_key_values.layers[self.layer_idx].values

    # Tensor predicate: True if any element is non-zero.
    # Result is a 0-dim bool tensor suitable for torch.cond.
    cache_is_initialized = (cached_keys != 0).any()

    # Use torch.cond to select the branch in a traceable way.
    # All operands must be (nested) tensors or simple Python values.
    key_states, value_states = torch.cond(
        cache_is_initialized,
        use_cached_kv,
        recompute_kv,
        operands=(cached_keys, cached_values, key_value_states),
    )
```
1 parent fee1b2d commit 00e89d7
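
To see the restriction the message refers to using only stock ops, here is a minimal sketch; the function names, shapes, and the try/except are illustrative and not part of this commit. One branch mutates an operand in place and the other returns an operand unchanged, and `torch.compile` is expected to reject both patterns inside `torch.cond`:

```py
import torch


def update_branch(cache, value):
    # Mutates a branch operand in place: not allowed inside torch.cond branches.
    cache[:, :, : value.size(2), :].copy_(value)
    return cache


def reuse_branch(cache, value):
    # Returns an operand as-is, so the output aliases an input: also not allowed.
    return cache


@torch.compile(fullgraph=True)
def f(pred, cache, value):
    return torch.cond(pred, update_branch, reuse_branch, (cache, value))


cache = torch.zeros(1, 2, 8, 4)   # (B, H, S_max, D)
value = torch.randn(1, 2, 3, 4)   # (B, H, S, D)
try:
    f(torch.tensor(True), cache, value)
except Exception as e:
    # Expect a mutation/aliasing error from the cond higher-order op;
    # the exact exception type and message depend on the PyTorch version.
    print(type(e).__name__)
```

The two custom ops in this commit exist precisely to express these branches without visible mutation or aliasing.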

2 files changed: +270 -0 lines changed


extension/llm/custom_ops/custom_ops.py

Lines changed: 92 additions & 0 deletions

```diff
@@ -16,6 +16,12 @@
 
 from torch.library import impl
 
+aten = torch.ops.aten
+
+from typing import Tuple
+
+from torch._inductor.lowering import lowerings as L, register_lowering
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,89 @@ def custom_quantized_sdpa_meta(
     )
 
     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Runtime implementation: just return the inputs as-is.
+    Works for both CPU and CUDA tensors because we don't do
+    any device-specific work here.
+    """
+    # no copies, just pass-through
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata for compile-time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op – just pass them through.
+    return x, y
+
+
+# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# This is cheating: we deliberately do NOT mark `cache` as mutating so that this
+# custom op can be used in HOPs such as `torch.cond`, where `torch.compile` requires
+# no aliasing or mutation in the branches. This is fine because we only care about inference.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    # Eager implementation
+    _validate_cross_attn_cache_params(value, cache)
+
+    # Slice the cache to match value's sequence length and copy
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+    cache[:, :, : value.size(2), :].copy_(value)
+    return cache
+
+
+# Register the fake (meta) kernel
+@_update_cross_attn_cache.register_fake
+def _update_cross_attn_cache_fake(
+    value: torch.Tensor, cache: torch.Tensor
+) -> torch.Tensor:
+    _validate_cross_attn_cache_params(value, cache)
+    return torch.empty_like(cache)
+
+
+# Register Inductor lowering
+@register_lowering(torch.ops.executorch.update_cross_attn_cache)
+def _update_cross_attn_cache_lowering(value, cache):
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+
+    # We need to slice the cache along dim 2 (sequence length)
+    # slice(self, dim, start, end, step=1)
+    seq_len = value.get_size()[2]
+    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
+
+    # Copy value into the slice
+    L[aten.copy_.default](cache_slice, value)
+
+    return cache
```
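
With both ops registered, the rewrite from the commit message can be sketched end to end outside a real attention module. The sketch below is illustrative: `kv_proj`, `cross_attn_kv`, and the shapes are stand-ins for the projections and caches a model would own, and it follows the `torch.compile` + `torch.cond` path that the tests further down exercise, using the `(B, H, S_max, D)` cache layout documented in the op:

```py
import torch
from torch import Tensor

# Importing this module registers executorch::alias and executorch::update_cross_attn_cache.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

B, H, S_max, S, D = 1, 2, 8, 3, 4
kv_proj = torch.nn.Linear(D, H * D)  # stand-in for k_proj / v_proj


def use_cached_kv(cached_keys: Tensor, cached_values: Tensor, key_value_states: Tensor):
    # Reuse the cache; executorch::alias avoids returning branch inputs directly.
    return torch.ops.executorch.alias(cached_keys, cached_values)


def recompute_kv(cached_keys: Tensor, cached_values: Tensor, key_value_states: Tensor):
    # Fresh K/V, written into the (B, H, S_max, D) caches via the non-mutating op.
    k = kv_proj(key_value_states).view(B, -1, H, D).transpose(1, 2).contiguous()
    v = kv_proj(key_value_states).view(B, -1, H, D).transpose(1, 2).contiguous()
    k = torch.ops.executorch.update_cross_attn_cache(k, cached_keys)
    v = torch.ops.executorch.update_cross_attn_cache(v, cached_values)
    return k, v


@torch.compile
def cross_attn_kv(cached_keys, cached_values, key_value_states):
    # 0-dim bool tensor predicate: cache counts as initialized if any entry is non-zero.
    cache_is_initialized = (cached_keys != 0).any()
    return torch.cond(
        cache_is_initialized,
        use_cached_kv,
        recompute_kv,
        operands=(cached_keys, cached_values, key_value_states),
    )


cached_keys = torch.zeros(B, H, S_max, D)
cached_values = torch.zeros(B, H, S_max, D)
key_value_states = torch.randn(B, S, D)

# First call: the cache is all zeros, so the recompute_kv branch runs and fills it.
key_states, value_states = cross_attn_kv(cached_keys, cached_values, key_value_states)
```

In the real attention module the projection weights and cache tensors come from `self` and `past_key_values`, as in the rewritten snippet in the commit message.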
New test file: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
```py
import unittest

import torch

# Import the custom ops to ensure they are registered
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


class TestUpdateCrossAttnCache(unittest.TestCase):
    def test_update_cross_attn_cache(self):

        # Create tensors
        # Cache: [B=2, S_max=4, H=1, D=4]
        cache = torch.zeros(2, 4, 1, 4, dtype=torch.float32)
        # Value: [B=2, S=2, H=1, D=4] (S < S_max)
        value = torch.randn(2, 2, 1, 4, dtype=torch.float32)

        # Compile a function that uses the op
        @torch.compile
        def fn(v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)

        # Run it
        out = fn(value, cache)

        # Check correctness
        # The first 2 elements in dim 1 should match value
        torch.testing.assert_close(
            cache[:, :2, :, :], value, msg="Cache slice not updated correctly"
        )

        # Make sure out and cache are close. In eager they are the same objects.
        torch.testing.assert_close(
            out, cache, msg="Output and cache are different objects"
        )

        # The rest should be zeros
        torch.testing.assert_close(
            cache[:, 2:, :, :],
            torch.zeros_like(cache[:, 2:, :, :]),
            msg="Rest of cache was modified",
        )

    def test_update_cross_attn_cache_in_cond(self):
        # Create tensors

        # Value: [B=2, S=2, H=1, D=4]
        value = torch.randn(2, 2, 1, 4, dtype=torch.float32)
        # Alternative value for false branch
        value_alt = torch.randn(2, 2, 1, 4, dtype=torch.float32)

        # Define a function that uses the op inside torch.cond
        def fn_with_cond(pred, v1, v2, c):
            def true_fn(v1, v2, cache):
                return torch.ops.executorch.update_cross_attn_cache(v1, cache)

            def false_fn(v1, v2, cache):
                return torch.ops.executorch.update_cross_attn_cache(v2, cache)

            # torch.cond takes a single operands tuple shared by both branches
            return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

        # Test with true condition
        pred_true = torch.tensor(True)
        cache_true = torch.zeros(2, 4, 1, 4, dtype=torch.float32)

        # Compile the function
        @torch.compile
        def compiled_fn(pred, v1, v2, c):
            return fn_with_cond(pred, v1, v2, c)

        # Run with true condition
        compiled_fn(pred_true, value, value_alt, cache_true)

        # Check that the true branch was executed (value was used)
        torch.testing.assert_close(
            cache_true[:, :2, :, :],
            value,
            msg="Cache not updated correctly in true branch",
        )

        # Test with false condition
        pred_false = torch.tensor(False)
        cache_false = torch.zeros(2, 4, 1, 4, dtype=torch.float32)

        compiled_fn(pred_false, value, value_alt, cache_false)

        # Check that the false branch was executed (value_alt was used)
        torch.testing.assert_close(
            cache_false[:, :2, :, :],
            value_alt,
            msg="Cache not updated correctly in false branch",
        )

    def test_update_cross_attn_cache_export(self):

        # Create tensors
        # Cache: [B=2, S_max=4, H=1, D=4]
        cache = torch.zeros(2, 4, 1, 4, dtype=torch.float32)
        # Value: [B=2, S=2, H=1, D=4]
        value = torch.randn(2, 2, 1, 4, dtype=torch.float32)

        # Define a module that uses the op
        class UpdateCacheModule(torch.nn.Module):
            def forward(self, v, c):
                return torch.ops.executorch.update_cross_attn_cache(v, c)

        module = UpdateCacheModule()

        # Export the module
        exported_program = torch.export.export(
            module,
            (value, cache),
        )

        # Run the exported program
        cache_exported = torch.zeros(2, 4, 1, 4, dtype=torch.float32)
        exported_program.module()(value, cache_exported)

        # Check correctness
        torch.testing.assert_close(
            cache_exported[:, :2, :, :],
            value,
            msg="Cache not updated correctly after export",
        )

    def test_update_cross_attn_cache_different_shapes(self):
        print("Testing executorch::update_cross_attn_cache with different shapes...")

        # Test with different batch sizes and sequence lengths
        test_cases = [
            # (B, S_max, S, H, D)
            (1, 10, 5, 2, 8),
            (4, 8, 3, 4, 16),
            (2, 16, 10, 1, 32),
        ]

        for B, S_max, S, H, D in test_cases:
            cache = torch.zeros(B, S_max, H, D, dtype=torch.float32)
            value = torch.randn(B, S, H, D, dtype=torch.float32)

            @torch.compile
            def fn(v, c):
                return torch.ops.executorch.update_cross_attn_cache(v, c)

            fn(value, cache)

            # Check that the first S positions are updated
            torch.testing.assert_close(
                cache[:, :S, :, :],
                value,
                msg=f"Failed for shape B={B}, S_max={S_max}, S={S}, H={H}, D={D}",
            )

            # Check that the rest remain zeros
            if S < S_max:
                torch.testing.assert_close(
                    cache[:, S:, :, :],
                    torch.zeros_like(cache[:, S:, :, :]),
                    msg=f"Remaining cache modified for shape B={B}, S_max={S_max}, S={S}, H={H}, D={D}",
                )

    def test_update_cross_attn_cache_full_sequence(self):

        # Cache: [B=2, S_max=4, H=1, D=4]
        cache = torch.zeros(2, 4, 1, 4, dtype=torch.float32)
        # Value: [B=2, S=4, H=1, D=4] (S == S_max)
        value = torch.randn(2, 4, 1, 4, dtype=torch.float32)

        @torch.compile
        def fn(v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)

        fn(value, cache)

        # The entire cache should match value
        torch.testing.assert_close(
            cache, value, msg="Cache not fully updated when S == S_max"
        )
```
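
Beyond the runtime asserts above, one way to sanity-check the `mutates_args=[]` trick after export is to look at the captured graph: the update should appear as an ordinary functional `call_function` node, with no in-place copy on the `cache` input. A small sketch, where the shapes follow the `(B, H, S_max, D)` layout documented in the op and the graph walk is illustrative:

```py
import torch

# Importing registers the custom ops.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


class UpdateCacheModule(torch.nn.Module):
    def forward(self, value, cache):
        return torch.ops.executorch.update_cross_attn_cache(value, cache)


cache = torch.zeros(2, 1, 4, 4)   # (B, H, S_max, D)
value = torch.randn(2, 1, 2, 4)   # (B, H, S, D)

ep = torch.export.export(UpdateCacheModule(), (value, cache))

# Since the op is registered with mutates_args=[], it is captured as a plain
# functional call; print the operators that ended up in the exported graph.
for node in ep.graph.nodes:
    if node.op == "call_function":
        print(node.target)
```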
