Conversation

@meta-emilian
Contributor

@meta-emilian meta-emilian commented Aug 21, 2024

Summary:
This change makes torch.ops.llama.sdpa_with_kv_cache batch-aware. This is needed for batched SDPA cases, for example LLM beam search.

  • Makes update_cache update across the batch dimension

As a performance optimization, update_cache implements the following operation

    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v

as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops. sdpa_with_kv_cache fuses this update with the flash attention op for tensors that follow a predetermined format, [batch, length, heads, dim]. This change removes the assumption that batch == 1.

  • Makes sdpa_with_kv_cache apply cpu_flash_attention to every entry in the batch as well (see the eager-mode sketch below).

ExecuTorch-exported Llama models are implemented with a greedy search, so this op has not needed to be batch-aware. However, when working with other models or when doing LLM beam search, that is no longer true.
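
For reference, the batch-aware semantics of the fused op can be sketched in eager PyTorch roughly as below. This is a minimal illustrative sketch, not the actual kernel: the function name is hypothetical, causal masking and grouped-query attention are omitted, and F.scaled_dot_product_attention stands in for the op's internal cpu_flash_attention.

```
import torch.nn.functional as F

def sdpa_with_kv_cache_reference(q, k, v, k_cache, v_cache, start_pos):
    # Hypothetical eager-mode reference for the fused op's batch-aware behavior.
    # q, k, v:          [batch, seq_len, heads, dim]
    # k_cache, v_cache: [batch, max_len, heads, dim], updated in place
    seq_len = q.size(1)

    # Batch-aware cache update: every batch entry is written, not just batch 0.
    # This is the slice assignment that update_cache fuses to avoid slice-scatter.
    k_cache[:, start_pos:start_pos + seq_len] = k
    v_cache[:, start_pos:start_pos + seq_len] = v

    # Attend over the cached keys/values for all batch entries.
    # scaled_dot_product_attention expects [batch, heads, seq, dim],
    # so swap the length and heads dimensions around the call.
    q_t = q.transpose(1, 2)
    k_t = k_cache[:, : start_pos + seq_len].transpose(1, 2)
    v_t = v_cache[:, : start_pos + seq_len].transpose(1, 2)
    out = F.scaled_dot_product_attention(q_t, k_t, v_t)  # masking omitted for brevity
    return out.transpose(1, 2)  # back to [batch, seq_len, heads, dim]
```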

Reviewed By: kimishpatel

Differential Revision: D61605316

@pytorch-bot

pytorch-bot bot commented Aug 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/4822

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 9508947 with merge base 06c0fa3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Aug 21, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D61605316

meta-emilian added a commit to meta-emilian/executorch that referenced this pull request Aug 21, 2024
Summary:
Pull Request resolved: pytorch#4822

This is part 1 of a multi-part commit to make torch.ops.llama.sdpa_with_kv_cache batch-aware. This is needed for batched SDPA cases, for example LLM beam search.

As a performance optimization, update_cache implements the following operation
```
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```
as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops.

ExecuTorch-exported Llama models are implemented with a greedy search, so this op has not needed to be batch-aware. However, when working with other models or when doing LLM beam search, this code needs to update the cache across the batch dimension.

Differential Revision: D61605316
meta-emilian added a commit to meta-emilian/executorch that referenced this pull request Sep 16, 2024
meta-emilian added a commit to meta-emilian/executorch that referenced this pull request Sep 17, 2024
meta-emilian added a commit to meta-emilian/executorch that referenced this pull request Sep 17, 2024
@meta-emilian meta-emilian changed the title from "Making update_cache update across the batch dimension." to "[ExecuTorch] Batch-aware torch.ops.llama.sdpa_with_kv_cache." Sep 17, 2024
meta-emilian added a commit to meta-emilian/executorch that referenced this pull request Sep 17, 2024
@facebook-github-bot
Contributor

This pull request has been merged in 53c1a5f.
