-
Notifications
You must be signed in to change notification settings - Fork 680
[DRAFT] Factor out core SDPA #1561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1561
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures, 5 Cancelled JobsAs of commit f506e22 with merge base 9a863c8 ( NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Hey @dvorjackz can you share the motivation for this PR? (I know it's still a draft, just wanna understand what the goal is for when I do eventually review it) |
|
@ebsmothers added context to the pr description! |
kimishpatel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
left some comments
acffd0a to
e61cf56
Compare
|
left another nit. Looks good to me. Make sure it is exportable. And want to hear thoughts from Tune folks |
|
hey @dvorjackz , thanks for the PR and the extra context! It makes sense to me to make it swappable, but this is a relatively large change. I am not sure how this will interact with compile + multimodal. There are also other ongoing PRs that are modifying kv_cache. We may need to align on the design a bit. I want to minimize the amount of work you have to do, but if your version is working, adding testing will make it much easier to approve. (e.g. you could run supervised training for vision model for 50 steps with / without the PR, to compare that everything works fine, and then run a generation task.) But we should align on the design first. |
ebsmothers
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay in getting to this one. Tbh I am not sure I like the way we are refactoring the multi-head attention here. I think there is a pretty canonical MHA flow that folks in OSS are used to (see e.g. litgpt's CausalSelfAttention or transformers's LlamaAttention) and this would be diverging from that in a meaningful way. I am OK with such divergence if it makes the code easier to understand, but in this case we are actually adding another layer of abstraction that's not very clear (why do we need a separate module to handle a couple reshapes + SDPA? why do we call it SDPA when it in fact contains a call to nn.functional.scaled_dot_product_attention, which we then call self._attn_fn?)
Anyways this is not to be too harsh on this PR cause I do understand the motivation from the ET perspective. Just so that I am more well-informed here, can you share the optimized SDPA op from ET so I can look at it as a reference? Then maybe we can brainstorm about a less intrusive way we can achieve this.
Evan, But I would not use that as an argument to necessitate this change. This change should also stand on its own merit. To me decoupling attention from q, k, v projections plus rope is good. For example if I were to apply sliding window attention like the one in here https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py#L274, I will have to make more intrusive changes. If I were to have ring buffer kv cache, then also I will have to update MHA implementation. So to me decoupling projection + rope from SDPA helps. In this case we add kv cache to SDPA as well because it is really relevant to attention. |
|
Thanks @kimishpatel for the explanation and the code pointer, this helps a lot. I have a handful of follow-up comments: (1) Sliding window attention is an interesting example. I had been thinking recently about this, but from the perspective of using (2) Kind of related to (1).. I realized my claim of "we shouldn't add layers of abstraction to MHA" wasn't 100% honest since we've already kinda done it (sorry about that). We do already have SDPA parametrized for flex attention support. I think the thing I don't like is that in this PR we are keeping the existing parametrization of (3) I do agree with @felipemello1's previous comment about some of our ongoing KV cache refactors (also about interactions with other features, but that can come after imo). One comment I have up front is that it's a bit weird to make SDPA stateful (but I'm just biased and like to keep everything functional). Ultimately I care more about correctness: @SalmanMohammadi and @joecummings have been doing a lot here recently, so I wanna have them take a look to be sure this makes sense from their perspective. Sorry I know there's a lot here, but TLDR is I generally agree with the idea of factoring out the actual SDPA op provided we can do it cleanly and consistently. I think there are a few changes to make in this PR to get there, but happy to get more into the weeds with the design than I have here if that'd help. Other than that the main open question I have is around the KV cache and ensuring we don't break anything there. |
Ignore that. I dont know why we named it such. Not really a flex attention and not a good example to focus
This does make sense. I need to think a bit on how to make that work with SDPA that we have which is really why all the reshape and tranposes have moved inside, but your point is fair. Let me think about it.
Yeah thats true and this part is probably very specific to what ET wants. Let me actually spend some time on this to iron out details |
|
Hi folks, shelving this for now since we chose to just add a clone of MultiHeadAttention into ET to source transform in pytorch/executorch#6719 |
Context
This PR factors out the optimizable portions of SDPA (namely the kv cache update, the transpose, the expand, and the actual sdpa). This allows a module containing optimized implementations of the above functionalities to easily be swapped in with the new module via source transformation.
Atm, ET has an optimized SDPA op that does all of the above (kv cache update, transpose, expand, sdpa) that we are hoping to swap in pre-export. cc @kimishpatel.
Proof of correctness
Before change
After change
We see essentially the exact same loss (slight difference since only trained for one epoch). Interestingly, we also squeeze out ~9% more tokens per gpu after this refactor.
Changelog
Factor out inference-optimizable portions of SDPA
Test plan
[Pending] Went through a quick export process as sanity check, will do more extensive correctness checking
pre-commit install)pytest testspytest tests -m integration_testUX
No public API changes