[ExecuTorch] Batch-aware torch.ops.llama.sdpa_with_kv_cache. #4822
Conversation
🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/4822
⏳ No failures, 1 pending as of commit 9508947 with merge base 06c0fa3.
This pull request was exported from Phabricator. Differential Revision: D61605316
Summary: Pull Request resolved: pytorch#4822

This is part 1 of a multi-part commit to make torch.ops.llama.sdpa_with_kv_cache batch-aware. This is needed for batched SDPA cases, for example LLM beam search.

As a performance optimization, update_cache implements the following operation

```
k_cache[:, start_pos : start_pos + seq_len, :, :] = k
v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```

as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops.

ExecuTorch-exported Llama models use a greedy search, so it has not been necessary for this op to be batch-aware. However, when working with other models, or when doing LLM beam search, this code needs to update the cache across the batch dimension.

Differential Revision: D61605316
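For illustration only, here is a minimal eager-mode PyTorch sketch of the update that the fused op performs; the function name and docstring shapes are assumptions for readability, not the actual kernel interface. Exporting this indexing pattern naively is what introduces the slice-scatter ops mentioned above.

```python
import torch


def update_cache_reference(k, v, k_cache, v_cache, start_pos):
    """Reference semantics of the fused cache update (illustrative only).

    k, v:             [batch, seq_len, n_heads, head_dim] new key/value projections
    k_cache, v_cache: [batch, max_seq_len, n_heads, head_dim] preallocated caches
    """
    seq_len = k.size(1)
    # Batch-aware: write the new tokens into every batch line, not just batch index 0.
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v
    return k_cache, v_cache
```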
Force-pushed 0f8c06a to 54bdb58.
Force-pushed 54bdb58 to ae0964a.
Force-pushed ae0964a to 1e0568b.
Force-pushed 1e0568b to a9dae49.
Force-pushed a9dae49 to 377e496.
Force-pushed 377e496 to 9508947.
This pull request has been merged in 53c1a5f.
Summary:
This change makes torch.ops.llama.sdpa_with_kv_cache batch-aware. This is needed for batched SDPA cases, for example LLM beam search.

* Makes update_cache update across the batch dimension.

As a performance optimization, update_cache implements the following operation

```
k_cache[:, start_pos : start_pos + seq_len, :, :] = k
v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```

as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops. sdpa_with_kv_cache fuses this update with the flash attention op for tensors that follow a predetermined format [batch, length, heads, dim]. This change removes the assumption that batch == 1.

* Makes sdpa_with_kv_cache apply cpu_flash_attention across all batch lines as well.

ExecuTorch-exported Llama models use a greedy search, so it has not been necessary for this op to be batch-aware. However, when working with other models, or when doing LLM beam search, this is no longer true.
Reviewed By: kimishpatel
Differential Revision: D61605316
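As a rough reference for the batch-aware semantics described above (not the op's actual schema or kernel), the sketch below updates the caches across the batch dimension and then runs scaled dot-product attention over every batch line, transposing from the [batch, length, heads, dim] cache layout to the [batch, heads, length, dim] layout that torch.nn.functional.scaled_dot_product_attention expects. The helper name is hypothetical, masking is omitted for brevity, and equal numbers of query and key/value heads are assumed.

```python
import torch
import torch.nn.functional as F


def sdpa_with_kv_cache_reference(q, k, v, k_cache, v_cache, start_pos):
    """Illustrative batch-aware semantics; all tensors use [batch, length, heads, dim]."""
    seq_len = q.size(1)

    # Batch-aware cache update (previously only batch index 0 would have been written).
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v

    # Attend against everything cached so far, for all batch lines.
    keys = k_cache[:, : start_pos + seq_len]
    values = v_cache[:, : start_pos + seq_len]

    # [batch, length, heads, dim] -> [batch, heads, length, dim] for SDPA, then back.
    # No causal mask is needed for single-token decode, since only already-generated
    # positions are present in the cache.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
    )
    return out.transpose(1, 2)
```

With a beam width of 4, for example, this writes and attends over all four beams at once instead of only beam 0, which is the behavior the batch == 1 assumption previously ruled out.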