58 changes: 17 additions & 41 deletions torchtune/modules/attention.py
@@ -11,6 +11,7 @@
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
from torchtune.modules.sdpa import SDPA

logger = logging.getLogger(__name__)

@@ -126,6 +127,8 @@ def __init__(
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.is_causal = is_causal
# Number of query heads per key/value head
self.q_per_kv = self.num_heads // self.num_kv_heads

# Set layers
self.kv_cache = kv_cache
@@ -139,6 +142,16 @@ def __init__(

# Use flex attention if supported and we are sample packing
self._attention_call = _sdpa_or_flex_attention()
self._sdpa = SDPA(
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
head_dim=self.head_dim,
q_per_kv=self.q_per_kv,
attn_dropout=self.attn_dropout,
is_causal=self.is_causal,
attention_fn=self._attention_call,
kv_cache=self.kv_cache,
)

def setup_cache(
self, batch_size: int, dtype: torch.dtype, max_seq_len: int
@@ -227,18 +240,12 @@ def forward(

# q has shape [b, s_x, num_heads * head_dim]
q = self.q_proj(x)

# number of queries per key/value
q_per_kv = self.num_heads // self.num_kv_heads
q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)

# Apply positional embeddings
if self.pos_embeddings is not None:
q = self.pos_embeddings(q, input_pos=input_pos)

# [b, n_h, s_x, h_d]
q = q.transpose(1, 2)

# Normalize q
if self.q_norm is not None:
q = self.q_norm(q)
@@ -261,48 +268,17 @@ def forward(
# Apply positional embeddings
# k: [b, s_y, n_kv, h_d]
k = k.view(b, s_y, -1, self.head_dim)
v = v.view(b, s_y, -1, self.head_dim)
if self.pos_embeddings is not None:
k = self.pos_embeddings(k, input_pos=input_pos)

# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.

# k: [b, s_y, n_kv, 1, h_d]
# v: [b, s_y, n_kv, 1, h_d]
k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)

# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
if self.num_heads != self.num_kv_heads:
k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)

# [b, s, n_h, h_d]
k = k.reshape(b, s_y, -1, self.head_dim)
v = v.reshape(b, s_y, -1, self.head_dim)

# [b, n_h, s, h_d]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None:
k, v = self.kv_cache.update(k, v)

output = self._attention_call(
q,
k,
v,
mask=mask,
dropout_p=self.attn_dropout,
is_causal=self.kv_cache is None and mask is None and self.is_causal,
)
k, v = self._sdpa.kv_cache_update(input_pos, k, v)

# reshape the output to be the same shape as the input
output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
output = self._sdpa(q, k, v, b, s_x, mask=mask)
return self.output_proj(output)
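For reference, the k/v head expansion that this diff moves out of MultiHeadAttention.forward and into SDPA.forward can be exercised on its own. A minimal sketch with purely illustrative sizes (none of these numbers come from the PR):

import torch

# Illustrative sizes only: 8 query heads share 2 kv heads, so q_per_kv = 4.
bsz, seq_len, num_heads, num_kv_heads, head_dim = 1, 4, 8, 2, 16
q_per_kv = num_heads // num_kv_heads

k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)

# [b, s, n_kv, h_d] -> [b, s, n_kv, 1, h_d]; expand copies each kv head
# across q_per_kv query heads; reshape gives [b, s, n_h, h_d] to match q.
k = k.view(bsz, seq_len, num_kv_heads, 1, head_dim)
k = k.expand(bsz, seq_len, num_kv_heads, q_per_kv, head_dim)
k = k.reshape(bsz, seq_len, num_heads, head_dim)

print(k.shape)  # torch.Size([1, 4, 8, 16])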
90 changes: 90 additions & 0 deletions torchtune/modules/sdpa.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

from torch import nn, Tensor


class SDPA(nn.Module):
"""
The core of SDPA, which can be optimized and swapped out
for a more efficient implementation. Split into a KV cache
update and the core SDPA (forward) component because the two
are easier to optimize separately.
"""

def __init__(
self,
num_kv_heads: int,
num_heads: int,
head_dim: int,
q_per_kv: int,
attn_dropout: float,
is_causal: bool,
attention_fn,
kv_cache,
) -> None:
super().__init__()
self.num_kv_heads = num_kv_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.q_per_kv = q_per_kv
self.attn_dropout = attn_dropout
self.is_causal = is_causal
self._attention_fn = attention_fn
self._kv_cache = kv_cache

def kv_cache_update(
self,
input_pos: Tensor,
k: Tensor,
v: Tensor,
) -> Tuple[Tensor, Tensor]:
k, v = self._kv_cache.update(input_pos, k, v)
return k, v

def forward(
self,
q: Tensor, # [b, s, n_h, h_d]
k: Tensor, # [b, s, n_kv, h_d]
v: Tensor, # [b, s, n_kv, h_d]
bsz: int,
seq_len: int,
mask: Optional[Tensor] = None,
) -> Tensor:
# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.

# k: [bsz, seq_len, n_kv, 1, h_d]
# v: [bsz, seq_len, n_kv, 1, h_d]
k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)

# Expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
if self.num_heads != self.num_kv_heads:
k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)

# [bsz, s, n_h, h_d]
k = k.reshape(bsz, seq_len, -1, self.head_dim)
v = v.reshape(bsz, seq_len, -1, self.head_dim)

# [bsz, n_h, s, h_d]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output = self._attention_fn(
q,
k,
v,
mask=mask,
dropout_p=self.attn_dropout,
is_causal=self._kv_cache is None and mask is None and self.is_causal,
)
# Reshape the output to be the same shape as the input
return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
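A minimal usage sketch of the new module, assuming this branch's torchtune/modules/sdpa.py is importable. The attention_fn wrapper below is only a stand-in for the callable returned by _sdpa_or_flex_attention(), the sizes are illustrative, and no kv cache is attached:

import torch
import torch.nn.functional as F

from torchtune.modules.sdpa import SDPA

# Stand-in for _sdpa_or_flex_attention(): adapt the kwarg names SDPA.forward
# uses to those of F.scaled_dot_product_attention.
def attention_fn(q, k, v, mask=None, dropout_p=0.0, is_causal=False):
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal
    )

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 8, 8, 2, 16
sdpa = SDPA(
    num_kv_heads=num_kv_heads,
    num_heads=num_heads,
    head_dim=head_dim,
    q_per_kv=num_heads // num_kv_heads,
    attn_dropout=0.0,
    is_causal=True,
    attention_fn=attention_fn,
    kv_cache=None,  # no cache attached, so kv_cache_update is never called
)

q = torch.randn(bsz, seq_len, num_heads, head_dim)
k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)
v = torch.randn(bsz, seq_len, num_kv_heads, head_dim)

out = sdpa(q, k, v, bsz, seq_len)
print(out.shape)  # torch.Size([2, 8, 128]) == [bsz, seq_len, num_heads * head_dim]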