Skip to content

Conversation

@larryliu0820
Copy link
Contributor

@larryliu0820 larryliu0820 commented Nov 21, 2025

torch.cond doesn't take aliasing or mutations. Adding 2 ops for supporting conditionally updating kv cache:

  • executorch::alias: takes 2 tensors and return the same 2 tensors.
  • executorch::update_cross_attn_cache: takes a tensor cache and a tensor value, in place copy value into cache.

With these 2 ops, we can rewrite the model definition from:

if is_cross_attention and past_key_values and is_updated:
    # reuse k,v, cross_attentions
    key_states = past_key_values.layers[self.layer_idx].keys
    value_states = past_key_values.layers[self.layer_idx].values
else:
    key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    if past_key_values is not None:
        # save all key/value_states to cache to be re-used for fast auto-regressive generation
        cache_position = cache_position if not is_cross_attention else None
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, {"cache_position": cache_position}
        )

Into:

def use_cached_kv(
    cached_keys: Tensor,
    cached_values: Tensor,
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Just reuse cached K/V
    return torch.ops.executorch.alias(cached_keys, cached_values)

def recompute_kv(
    cached_keys: Tensor,  # unused
    cached_values: Tensor,  # unused
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Compute fresh K/V (export-friendly: use custom op to mutate cache)
    key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
    v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
    return k, v

if past_key_values is not None and self.layer_idx is not None:
    # Grab cached tensors (these are Tensors, so they are OK for export)
    cached_keys = past_key_values.layers[self.layer_idx].keys
    cached_values = past_key_values.layers[self.layer_idx].values

    # Tensor predicate: True if any element is non-zero
    # Result is a 0-dim bool tensor suitable for torch.cond
    cache_is_initialized = (cached_keys != 0).any()

    # Use torch.cond to select branch in a traceable way.
    # All operands must be (nested) tensors or simple Python values.
    key_states, value_states = torch.cond(
        cache_is_initialized,
        use_cached_kv,
        recompute_kv,
        operands=(cached_keys, cached_values, key_value_states),
    )

Summary

[PLEASE REMOVE] See CONTRIBUTING.md's Pull Requests for ExecuTorch PR guidelines.

[PLEASE REMOVE] If this PR closes an issue, please add a Fixes #<issue-id> line.

[PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: " label. For a list of available release notes labels, check out CONTRIBUTING.md's Pull Requests.

Test plan

[PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15937

Note: Links to docs will display an error until the docs builds have been completed.

❌ 20 New Failures

As of commit 00e89d7 with merge base fee1b2d (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 21, 2025
@larryliu0820 larryliu0820 added the release notes: desktop for desktop/laptop workstream label Nov 21, 2025
torch.cond doesn't take aliasing or mutations. Adding 2 ops for
supporting conditionally updating kv cache:

* `executorch::alias`: takes 2 tensors and return the same 2 tensors.
* `executorch::cross_attn_cache_update`: takes a tensor `cache` and a
  tensor `value`, in place copy `value` into `cache`.

With these 2 ops, we can rewrite the model definition from:

```py
if is_cross_attention and past_key_values and is_updated:
    # reuse k,v, cross_attentions
    key_states = past_key_values.layers[self.layer_idx].keys
    value_states = past_key_values.layers[self.layer_idx].values
else:
    key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    if past_key_values is not None:
        # save all key/value_states to cache to be re-used for fast auto-regressive generation
        cache_position = cache_position if not is_cross_attention else None
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, {"cache_position": cache_position}
        )
```

Into:

```py
def use_cached_kv(
    cached_keys: Tensor,
    cached_values: Tensor,
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Just reuse cached K/V
    return torch.ops.executorch.alias(cached_keys, cached_values)

def recompute_kv(
    cached_keys: Tensor,  # unused
    cached_values: Tensor,  # unused
    key_value_states: Tensor,
) -> tuple[Tensor, Tensor]:
    # Compute fresh K/V (export-friendly: no cache mutation in here)
    key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    k = torch.ops.executorch.update_cross_attn_cache(key_states, cached_keys)
    v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
    return k, v

if past_key_values is not None and self.layer_idx is not None:
    # Grab cached tensors (these are Tensors, so they are OK for export)
    cached_keys = past_key_values.layers[self.layer_idx].keys
    cached_values = past_key_values.layers[self.layer_idx].values

    # Tensor predicate: True if any element is non-zero
    # Result is a 0-dim bool tensor suitable for torch.cond
    cache_is_initialized = (cached_keys != 0).any()

    # Use torch.cond to select branch in a traceable way.
    # All operands must be (nested) tensors or simple Python values.
    key_states, value_states = torch.cond(
        cache_is_initialized,
        use_cached_kv,
        recompute_kv,
        operands=(cached_keys, cached_values, key_value_states),
    )
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. release notes: desktop for desktop/laptop workstream

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants