
Conversation

@jackzhxng (Contributor) commented Sep 12, 2024

Context

This PR factors out the optimizable portions of SDPA (namely the kv cache update, the transpose, the expand, and the actual SDPA call) into a separate module. A module containing optimized implementations of these operations can then easily be swapped in for the new module via source transformation.

At the moment, ET has an optimized SDPA op that does all of the above (kv cache update, transpose, expand, SDPA) and that we are hoping to swap in pre-export. cc @kimishpatel.
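To make the intent concrete, here is a rough sketch (not the actual torchtune or ET code; the class and helper names are illustrative) of what the factored-out module and the source-transformation swap could look like:

```python
import torch.nn as nn
import torch.nn.functional as F


class SDPA(nn.Module):
    """Illustrative module that owns the kv cache update, transpose/expand, and the attention call."""

    def __init__(self, kv_cache=None, q_per_kv: int = 1):
        super().__init__()
        self.kv_cache = kv_cache  # e.g. a torchtune-style KVCache, or None
        self.q_per_kv = q_per_kv  # n_heads // n_kv_heads for GQA

    def forward(self, q, k, v, mask=None):
        # Inputs assumed [b, s, n_heads, head_dim]; transpose to [b, n_heads, s, head_dim] for SDPA.
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(k, v)          # kv cache update
        if self.q_per_kv > 1:                          # expand kv heads for GQA
            k = k.repeat_interleave(self.q_per_kv, dim=1)
            v = v.repeat_interleave(self.q_per_kv, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2)                     # back to [b, s, n_heads, head_dim]


def replace_sdpa(module: nn.Module, make_optimized) -> nn.Module:
    """Source transformation sketch: recursively swap every SDPA for an optimized version."""
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(module, name, make_optimized(child))
        else:
            replace_sdpa(child, make_optimized)
    return module
```

The point is that everything a backend might want to fuse (cache update, transpose/expand, attention) sits behind one small module boundary.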

Proof of correctness

Before change

$ tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device max_steps_per_epoch=25  epochs=1 metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True
.
.
.
1|25|Loss: 1.487076759338379: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [11:27<00:00, 27.49s/it]
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               global_step ▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
wandb:                      loss ▆▆▇▆▆▆▇▇▅▇▅▅█▅▆▆▅▅▄▃▃▃▂▂▁
wandb:                        lr ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▇▇▇▇██
wandb:        peak_memory_active ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:         peak_memory_alloc ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:      peak_memory_reserved ▁▁▂██████████████████████
wandb: tokens_per_second_per_gpu ▃▄▃▃▄▃▂▂▆▄▅█▁▅▂▂▂▄▄▃▃▄▃▂▃
wandb: 
wandb: Run summary:
wandb:               global_step 25
wandb:                      loss 1.48708
wandb:                        lr 7e-05
wandb:        peak_memory_active 16.20987
wandb:         peak_memory_alloc 16.20987
wandb:      peak_memory_reserved 18.77148
wandb: tokens_per_second_per_gpu 1313.99555
wandb: 
wandb: 🚀 View run flowing-lion-18 at: https://wandb.ai/dvorjackz-meta/torchtune/runs/imx4iros
wandb: ⭐️ View project at: https://wandb.ai/dvorjackz-meta/torchtune
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: /tmp/wandb/run-20240920_154125-imx4iros/logs

After change

We see essentially the same loss (with a slight difference since we only trained for one epoch). Interestingly, we also squeeze out ~9% more tokens per second per GPU after this refactor.

$ tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device max_steps_per_epoch=25  epochs=1 metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True
.
.
.
1|25|Loss: 1.480817198753357: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [11:25<00:00, 27.41s/it]
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               global_step ▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
wandb:                      loss ▆▆▇▇▆▆▇▇▅▇▅▅█▅▆▆▅▅▄▃▃▂▂▂▁
wandb:                        lr ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▇▇▇▇██
wandb:        peak_memory_active ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:         peak_memory_alloc ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:      peak_memory_reserved ▁▁▂██████████████████████
wandb: tokens_per_second_per_gpu ▄▆▆▇▇▃▃▁▆▂▅█▂▄▃▅▃▃▃▃▄▄▃▃▆
wandb: 
wandb: Run summary:
wandb:               global_step 25
wandb:                      loss 1.48082
wandb:                        lr 7e-05
wandb:        peak_memory_active 16.20987
wandb:         peak_memory_alloc 16.20987
wandb:      peak_memory_reserved 18.77148
wandb: tokens_per_second_per_gpu 1434.81084
wandb: 
wandb: 🚀 View run fresh-puddle-17 at: https://wandb.ai/dvorjackz-meta/torchtune/runs/ehh7osgb
wandb: ⭐️ View project at: https://wandb.ai/dvorjackz-meta/torchtune
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: /tmp/wandb/run-20240920_152752-ehh7osgb/logs

Changelog

Factor out inference-optimizable portions of SDPA

Test plan

[Pending] Went through a quick export process as a sanity check; will do more extensive correctness checking.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

No public API changes

@pytorch-bot (bot) commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1561

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 5 Cancelled Jobs

As of commit f506e22 with merge base 9a863c8:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 12, 2024
@ebsmothers (Contributor) commented:

Hey @dvorjackz, can you share the motivation for this PR? (I know it's still a draft, I just want to understand what the goal is for when I do eventually review it.)

@jackzhxng (Contributor, Author) commented:

@ebsmothers added context to the PR description!

@kimishpatel left a review comment:

left some comments

@jackzhxng force-pushed the jackxz/rewrite-attention-2 branch from acffd0a to e61cf56 on September 18, 2024 02:07
@kimishpatel commented:

Left another nit. Looks good to me. Make sure it is exportable, and I want to hear thoughts from the Tune folks.

@felipemello1 (Contributor) commented:

hey @dvorjackz, thanks for the PR and the extra context! It makes sense to me to make it swappable, but this is a relatively large change. I am not sure how this will interact with compile + multimodal. There are also other ongoing PRs that are modifying kv_cache. We may need to align on the design a bit.

I want to minimize the amount of work you have to do, but if your version is working, adding testing will make it much easier to approve (e.g. you could run supervised training for the vision model for 50 steps with/without the PR to confirm that everything works fine, and then run a generation task). But we should align on the design first.

@ebsmothers (Contributor) left a comment:

Sorry for the delay in getting to this one. Tbh I am not sure I like the way we are refactoring the multi-head attention here. I think there is a pretty canonical MHA flow that folks in OSS are used to (see e.g. litgpt's CausalSelfAttention or transformers's LlamaAttention) and this would be diverging from that in a meaningful way. I am OK with such divergence if it makes the code easier to understand, but in this case we are actually adding another layer of abstraction that's not very clear (why do we need a separate module to handle a couple reshapes + SDPA? why do we call it SDPA when it in fact contains a call to nn.functional.scaled_dot_product_attention, which we then call self._attn_fn?)

Anyways this is not to be too harsh on this PR cause I do understand the motivation from the ET perspective. Just so that I am more well-informed here, can you share the optimized SDPA op from ET so I can look at it as a reference? Then maybe we can brainstorm about a less intrusive way we can achieve this.

@kimishpatel commented:

> Sorry for the delay in getting to this one. Tbh I am not sure I like the way we are refactoring the multi-head attention here. I think there is a pretty canonical MHA flow that folks in OSS are used to (see e.g. litgpt's CausalSelfAttention or transformers's LlamaAttention) and this would be diverging from that in a meaningful way. I am OK with such divergence if it makes the code easier to understand, but in this case we are actually adding another layer of abstraction that's not very clear (why do we need a separate module to handle a couple reshapes + SDPA? why do we call it SDPA when it in fact contains a call to nn.functional.scaled_dot_product_attention, which we then call self._attn_fn?)
>
> Anyways this is not to be too harsh on this PR cause I do understand the motivation from the ET perspective. Just so that I am more well-informed here, can you share the optimized SDPA op from ET so I can look at it as a reference? Then maybe we can brainstorm about a less intrusive way we can achieve this.

Evan,
Your point is quite fair. See the custom SDPA implementation here: https://github.com/pytorch/executorch/blob/main/examples/models/llama2/source_transformation/sdpa.py#L19. Really, at the heart of it is SDPA with a kv cache, so the point of the PR is to refactor SDPA to own the cache. The crux of the issue is that when the model is exported, updates to the kv cache result in a ton of copies, which causes issues for both memory footprint and latency. The refactored SDPA allows ET to apply an ET-specific module swap that does the kv cache update via a custom op.

But I would not use that as an argument to necessitate this change; this change should also stand on its own merit. To me, decoupling attention from the q, k, v projections plus RoPE is good. For example, if I were to apply sliding window attention like the one here https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py#L274, I would have to make more intrusive changes. If I were to have a ring buffer kv cache, I would also have to update the MHA implementation. So to me, decoupling the projections + RoPE from SDPA helps. In this case we add the kv cache to SDPA as well because it is really relevant to attention.
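As a rough illustration of "SDPA owning the cache" (hypothetical code, not the torchtune or ET implementation): the cache lives as buffers on the attention module and is updated in place each step; after export those in-place writes get functionalized into copies, which is exactly what an ET-specific swap to a fused custom op would eliminate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SDPAWithCache(nn.Module):
    """Illustrative attention module that owns its kv cache as buffers and updates it in place."""

    def __init__(self, batch_size, num_heads, max_seq_len, head_dim, dtype=torch.float32):
        super().__init__()
        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def forward(self, q, k, v, input_pos):
        # q, k, v: [b, n_heads, s, head_dim]; input_pos: cache positions being written.
        # These in-place writes are what a backend-specific swap (e.g. a fused
        # "SDPA with kv cache" custom op) would replace after export.
        self.k_cache.index_copy_(2, input_pos, k)
        self.v_cache.index_copy_(2, input_pos, v)
        seen = int(input_pos[-1]) + 1
        # Masking omitted for brevity; attend over the cache filled so far.
        return F.scaled_dot_product_attention(
            q, self.k_cache[:, :, :seen], self.v_cache[:, :, :seen]
        )
```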

@ebsmothers
Copy link
Contributor

Thanks @kimishpatel for the explanation and the code pointer, this helps a lot. I have a handful of follow-up comments:

(1) Sliding window attention is an interesting example. I had been thinking recently about this, but from the perspective of using FlexAttention to enable it (instead of what they're doing in litgpt). I saw there is an SDPAFlex in ET, I'm curious how you're currently using that.

(2) Kind of related to (1): I realized my claim of "we shouldn't add layers of abstraction to MHA" wasn't 100% honest since we've already kinda done it (sorry about that). We do already have SDPA parametrized for flex attention support. I think the thing I don't like is that in this PR we are keeping the existing parametrization of FlexAttention vs functional.scaled_dot_product_attention and then wrapping that in another abstraction (and for some reason putting all the reshapes in there too). It seems to me like the SDPA class defined in this PR should (a) take in q, k, v in their SDPA-ready shapes, (b) do the KV caching, (c) check if the mask is a block mask or not, and (d) dispatch to either our compiled flex attention or vanilla functional.scaled_dot_product_attention based on (c); see the sketch after this comment. There may be some subtleties here around compiling FlexAttention that I'm missing, cc @RdoubleA for a sanity check there.

(3) I do agree with @felipemello1's previous comment about some of our ongoing KV cache refactors (also about interactions with other features, but that can come after imo). One comment I have up front is that it's a bit weird to make SDPA stateful (but I'm just biased and like to keep everything functional). Ultimately I care more about correctness: @SalmanMohammadi and @joecummings have been doing a lot here recently, so I wanna have them take a look to be sure this makes sense from their perspective.

Sorry I know there's a lot here, but TLDR is I generally agree with the idea of factoring out the actual SDPA op provided we can do it cleanly and consistently. I think there are a few changes to make in this PR to get there, but happy to get more into the weeds with the design than I have here if that'd help. Other than that the main open question I have is around the KV cache and ensuring we don't break anything there.
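A rough sketch of the dispatch described in (2), under assumed names (the real torchtune KVCache and FlexAttention wiring may differ; in practice the flex_attention call would likely be wrapped in torch.compile):

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import BlockMask, flex_attention


class SDPA(nn.Module):
    """Illustrative: takes q, k, v already in SDPA-ready shape, caches kv, and dispatches on mask type."""

    def __init__(self, kv_cache=None):
        super().__init__()
        self.kv_cache = kv_cache  # assumed to expose update(k, v) -> (k, v)

    def forward(self, q, k, v, mask=None):
        # (a) q, k, v arrive as [b, n_heads, s, head_dim]
        # (b) kv caching happens right next to the attention call
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(k, v)
        # (c) + (d): block mask -> flex attention, otherwise vanilla SDPA
        if isinstance(mask, BlockMask):
            return flex_attention(q, k, v, block_mask=mask)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```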

@kimishpatel commented:

@ebsmothers

> I saw there is an SDPAFlex in ET, I'm curious how you're currently using that.

Ignore that. I don't know why we named it that. It's not really flex attention and not a good example to focus on.

> It seems to me like the SDPA class defined in this PR should (a) take in q, k, v in their SDPA-ready shapes, (b) do the KV caching, (c) check if the mask is a block mask or not, and (d) dispatch to either our compiled flex attention or vanilla functional.scaled_dot_product_attention based on (c)

This does make sense. I need to think a bit about how to make that work with the SDPA that we have, which is really why all the reshapes and transposes have moved inside, but your point is fair. Let me think about it.

> One comment I have up front is that it's a bit weird to make SDPA stateful

Yeah, that's true, and this part is probably very specific to what ET wants. Let me actually spend some time on this to iron out the details.

@jackzhxng (Contributor, Author) commented Nov 13, 2024:

Hi folks, shelving this for now since we chose to just add a clone of MultiHeadAttention into ET to source transform in pytorch/executorch#6719.

@jackzhxng closed this on Nov 13, 2024