-
Notifications
You must be signed in to change notification settings - Fork 680
[DRAFT] Factor out core SDPA #1561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 7 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
fb870bc
Factor out core SDPA
jackzhxng e417acf
Kimish comment
jackzhxng e61cf56
Fixes
jackzhxng 46ccf12
Merge branch 'main' into jackxz/rewrite-attention-2
jackzhxng 2f914da
Fix bug
jackzhxng 2e10c51
Kimish comment
jackzhxng eea9b71
Improve docstring
jackzhxng f506e22
Kimish comment #2
jackzhxng File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import Tuple | ||
|
|
||
| from torch import nn, Tensor | ||
|
|
||
|
|
||
| class SDPA(nn.Module): | ||
| """ | ||
| The core of SDPA which can be optimized and can be swapped | ||
| out for a more efficient implementations. Split into | ||
| kv cache update and core sdpa (foward) components because | ||
| they are easier to optimize separately. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_kv_heads: int, | ||
| num_heads: int, | ||
| head_dim: int, | ||
| q_per_kv: int, | ||
| attn_dropout: float, | ||
| is_causal: bool, | ||
| attention_fn, | ||
| kv_cache, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.num_kv_heads = num_kv_heads | ||
| self.num_heads = num_heads | ||
| self.head_dim = head_dim | ||
| self.q_per_kv = q_per_kv | ||
| self.attn_dropout = attn_dropout | ||
| self.is_causal = is_causal | ||
| self._attention_fn = attention_fn | ||
| self._kv_cache = kv_cache | ||
|
|
||
| def kv_cache_update( | ||
| self, | ||
| input_pos: Tensor, | ||
| k: Tensor, | ||
| v: Tensor, | ||
| ) -> Tuple[Tensor, Tensor]: | ||
| k, v = self._kv_cache.update(input_pos, k, v) | ||
| return k, v | ||
|
|
||
| def forward( | ||
| self, | ||
| q: Tensor, # [b, s, n_h, h_d] | ||
| k: Tensor, # [b, s, n_kv, h_d] | ||
| v: Tensor, # [b, s, n_kv, h_d] | ||
| bsz: int, | ||
| seq_len: int, | ||
| mask: Tensor = None, | ||
| ) -> Tensor: | ||
| # View + expand + reshape bring num_kv_heads to num_heads for k and v | ||
| # to match q. | ||
|
|
||
| # k: [bsz, seq_len, n_kv, 1, h_d] | ||
| # v: [bsz, seq_len, n_kv, 1, h_d] | ||
| k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) | ||
| v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim) | ||
|
|
||
| # Expand the key and value tensors to have the same shape | ||
| # as the query tensor by copying values across the relevant dim | ||
| if self.num_heads != self.num_kv_heads: | ||
| k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) | ||
| v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim) | ||
|
|
||
| # [bsz, s, n_h, h_d] | ||
| k = k.reshape(bsz, seq_len, -1, self.head_dim) | ||
| v = v.reshape(bsz, seq_len, -1, self.head_dim) | ||
|
|
||
| # [bsz, n_h, s, h_d] | ||
| q = q.transpose(1, 2) | ||
| k = k.transpose(1, 2) | ||
| v = v.transpose(1, 2) | ||
| output = self._attention_fn( | ||
| q, | ||
| k, | ||
| v, | ||
| mask=mask, | ||
| dropout_p=self.attn_dropout, | ||
| is_causal=self._kv_cache is None and mask is None and self.is_causal, | ||
| ) | ||
| # Reshape the output to be the same shape as the input | ||
| return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.