
Conversation

@jackzhxng (Contributor) commented Jan 7, 2025

Summary

Use masking to avoid resorting to torch.cond, which disallows mutation inside its branches and therefore forces us to clone the KV cache, creating lots of unnecessary copies (sketched below).
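
As a rough illustration of the idea (not the code in this PR; the class, method, and buffer names below are hypothetical), the cache is kept at a fixed size and updated in place, and an attention mask hides the slots that have not been written yet, so no torch.cond branch is needed:

```python
import torch


class StaticKVCache(torch.nn.Module):
    """Illustrative fixed-size KV cache; not the PR's actual module."""

    def __init__(self, max_seq_len: int, n_heads: int, head_dim: int):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, max_seq_len, head_dim))

    def update(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Write the new entries into the pre-allocated buffers in place.
        # With no torch.cond branch, export does not have to clone the
        # whole cache to keep branches side-effect free.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


def valid_positions_mask(input_pos: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    # Mask that hides cache slots not yet written (and future positions),
    # so attending over the full-length cache is always safe.
    all_pos = torch.arange(max_seq_len, device=input_pos.device)
    return all_pos[None, :] <= input_pos[:, None]
```

The mask can then be passed to scaled dot-product attention (e.g. the attn_mask argument of torch.nn.functional.scaled_dot_product_attention), so attending over the full-length cache never picks up values from unwritten slots.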

Also gets past the current limitation that the partitioners don't automatically recurse into conditional subgraphs, allowing us to partition Llama 3.2 MM directly with XNNPACK (see the export sketch below).
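
For reference, a minimal sketch of the kind of lowering flow this unblocks, using the standard ExecuTorch export APIs (the model below is a placeholder, and the real Llama 3.2 MM export script will differ):

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge


class TinyDecoderStub(torch.nn.Module):
    """Placeholder standing in for the Llama 3.2 MM text model."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


example_inputs = (torch.randn(1, 16),)
exported = torch.export.export(TinyDecoderStub().eval(), example_inputs)

edge = to_edge(exported)
# With no torch.cond subgraphs left in the model, the XNNPACK partitioner
# can delegate the graph directly, without recursing into conditional branches.
edge = edge.to_backend(XnnpackPartitioner())
program = edge.to_executorch()

with open("model.pte", "wb") as f:
    f.write(program.buffer)
```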

Llama 3.2 MM comparison (XNNPACK + KV cache + custom SDPA), before vs. after this change:

| Metric | Before | After |
| --- | --- | --- |
| Activations | 0.95 GB | 0.27 GB |
| .pte size | 60 GB | 30 GB |
| Prefill | 20.91 s | 0.86 s |
| Generation | 0.28 tok/s | 6.26 tok/s |

Test plan

Rely on existing regression tests (test_attention and test_kv_cache, about to be merged), which have adequate coverage of KV cache and multi-head attention edge cases.

pytorch-bot bot commented Jan 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7536

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 220aec8 with merge base ca32105:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jan 7, 2025
github-actions bot commented Jan 7, 2025

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

```diff
  k, v = self.kv_cache.update(k, v)

- output = self._sdpa(q, k, v, b, s_x, mask=mask)
+ output = self._sdpa(q, k, v, b, s_x)
```
@jackzhxng (Contributor Author) commented:


Note to self: the refactor removed the mask arg, make sure to add it back
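
A minimal sketch of the intended follow-up, assuming the pre-refactor signature that accepted a mask keyword (names follow the diff above):

```python
# Hypothetical restoration of the dropped keyword argument.
output = self._sdpa(q, k, v, b, s_x, mask=mask)
```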

@jackzhxng closed this Feb 13, 2025
