diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py
index 2dfeaddc9a..fe8a158a4c 100644
--- a/torchtune/modules/attention.py
+++ b/torchtune/modules/attention.py
@@ -11,6 +11,7 @@
 from torch import nn
 
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from torchtune.modules.sdpa import SDPA
 
 logger = logging.getLogger(__name__)
@@ -126,6 +127,8 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
+        # Number of queries per k, v
+        self.q_per_kv = self.num_heads // self.num_kv_heads
 
         # Set layers
         self.kv_cache = kv_cache
@@ -139,6 +142,16 @@ def __init__(
 
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
+        self._sdpa = SDPA(
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            q_per_kv=self.q_per_kv,
+            attn_dropout=self.attn_dropout,
+            is_causal=self.is_causal,
+            attention_fn=self._attention_call,
+            kv_cache=self.kv_cache,
+        )
 
     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
@@ -227,18 +240,12 @@ def forward(
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-
-        # number of queries per key/value
-        q_per_kv = self.num_heads // self.num_kv_heads
-        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)
 
         # Apply positional embeddings
         if self.pos_embeddings is not None:
             q = self.pos_embeddings(q, input_pos=input_pos)
 
-        # [b, n_h, s_x, h_d]
-        q = q.transpose(1, 2)
-
         # Normalize q
         if self.q_norm is not None:
             q = self.q_norm(q)
@@ -261,48 +268,17 @@ def forward(
             # Apply positional embeddings
             # k: [b, s_y, n_kv, h_d]
             k = k.view(b, s_y, -1, self.head_dim)
+            v = v.view(b, s_y, -1, self.head_dim)
             if self.pos_embeddings is not None:
                 k = self.pos_embeddings(k, input_pos=input_pos)
 
-            # View + expand + reshape bring num_kv_heads to num_heads for k and v
-            # to match q.
-
-            # k: [b, s_y, n_kv, 1, h_d]
-            # v: [b, s_y, n_kv, 1, h_d]
-            k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-            v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-
-            # If needed, expand the key and value tensors to have the same shape
-            # as the query tensor by copying values across the relevant dim
-            if self.num_heads != self.num_kv_heads:
-                k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-                v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-            # [b, s, n_h, h_d]
-            k = k.reshape(b, s_y, -1, self.head_dim)
-            v = v.reshape(b, s_y, -1, self.head_dim)
-
-            # [b, n_h, s, h_d]
-            k = k.transpose(1, 2)
-            v = v.transpose(1, 2)
-
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
 
             # Update key-value cache
             if self.kv_cache is not None:
-                k, v = self.kv_cache.update(k, v)
-
-        output = self._attention_call(
-            q,
-            k,
-            v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
-        )
+                k, v = self._sdpa.kv_cache_update(input_pos, k, v)
 
-        # reshape the output to be the same shape as the input
-        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        output = self._sdpa(q, k, v, b, s_x)
 
         return self.output_proj(output)
diff --git a/torchtune/modules/sdpa.py b/torchtune/modules/sdpa.py
new file mode 100644
index 0000000000..3f0bb324a6
--- /dev/null
+++ b/torchtune/modules/sdpa.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple
+
+from torch import nn, Tensor
+
+
+class SDPA(nn.Module):
+    """
+    The core of SDPA, which can be optimized and swapped out
+    for a more efficient implementation. It is split into a
+    kv cache update and the core sdpa (forward) component because
+    the two are easier to optimize separately.
+    """
+
+    def __init__(
+        self,
+        num_kv_heads: int,
+        num_heads: int,
+        head_dim: int,
+        q_per_kv: int,
+        attn_dropout: float,
+        is_causal: bool,
+        attention_fn,
+        kv_cache,
+    ) -> None:
+        super().__init__()
+        self.num_kv_heads = num_kv_heads
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.q_per_kv = q_per_kv
+        self.attn_dropout = attn_dropout
+        self.is_causal = is_causal
+        self._attention_fn = attention_fn
+        self._kv_cache = kv_cache
+
+    def kv_cache_update(
+        self,
+        input_pos: Tensor,
+        k: Tensor,
+        v: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        k, v = self._kv_cache.update(k, v)
+        return k, v
+
+    def forward(
+        self,
+        q: Tensor,  # [b, s, n_h, h_d]
+        k: Tensor,  # [b, s, n_kv, h_d]
+        v: Tensor,  # [b, s, n_kv, h_d]
+        bsz: int,
+        seq_len: int,
+        mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # to match q.
+
+        # k: [bsz, seq_len, n_kv, 1, h_d]
+        # v: [bsz, seq_len, n_kv, 1, h_d]
+        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
+
+        # Expand the key and value tensors to have the same shape
+        # as the query tensor by copying values across the relevant dim
+        if self.num_heads != self.num_kv_heads:
+            k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
+
+        # [bsz, s, n_h, h_d]
+        k = k.reshape(bsz, seq_len, -1, self.head_dim)
+        v = v.reshape(bsz, seq_len, -1, self.head_dim)
+
+        # [bsz, n_h, s, h_d]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        output = self._attention_fn(
+            q,
+            k,
+            v,
+            mask=mask,
+            dropout_p=self.attn_dropout,
+            is_causal=self._kv_cache is None and mask is None and self.is_causal,
+        )
+        # Reshape the output to be the same shape as the input
+        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
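Note (not part of the diff): a minimal sketch of how the split is intended to be used downstream, assuming only the `SDPA` class and the `_sdpa` submodule introduced above. Once the attention math lives in its own `nn.Module`, it can be swapped for an optimized kernel with an ordinary module-replacement pass; the `MyOptimizedSDPA` wrapper and `replace_sdpa` helper below are hypothetical names for illustration, not part of this PR or the torchtune API.

```python
# Hypothetical source transformation (illustration only, not part of this PR):
# swap every SDPA submodule on a built model for an optimized variant.
from torch import nn

from torchtune.modules.sdpa import SDPA


class MyOptimizedSDPA(nn.Module):
    """Hypothetical drop-in replacement that reuses the original module's config."""

    def __init__(self, original: SDPA) -> None:
        super().__init__()
        self.inner = original

    def forward(self, q, k, v, bsz, seq_len, mask=None):
        # A real replacement would call a fused/custom attention kernel here;
        # this sketch simply delegates to the original implementation.
        return self.inner(q, k, v, bsz, seq_len, mask=mask)


def replace_sdpa(module: nn.Module) -> nn.Module:
    """Recursively replace every SDPA submodule (e.g. the attention layer's _sdpa)."""
    for name, child in module.named_children():
        if isinstance(child, SDPA):
            setattr(module, name, MyOptimizedSDPA(child))
        else:
            replace_sdpa(child)
    return module
```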