
Commit 34614de

Factor out core SDPA
1 parent 75f6975 commit 34614de

File tree

2 files changed: +90 / -41 lines changed


torchtune/modules/attention.py

Lines changed: 12 additions & 41 deletions
@@ -11,6 +11,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from torchtune.modules.sdpa import SDPA
 
 logger = logging.getLogger(__name__)
 
@@ -126,6 +127,8 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
+        # Number of queries per k, v
+        self.q_per_kv = self.num_heads // self.num_kv_heads
 
         # Set layers
         self.kv_cache = kv_cache
@@ -139,6 +142,11 @@ def __init__(
 
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
+        self._sdpa = SDPA(
+            attention_fn=self._attention_call,
+            kv_cache=self.kv_cache,
+            q_per_kv=self.q_per_kv,
+        )
 
     def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None:
         """Setup key value caches for attention calculation. If called
@@ -228,18 +236,12 @@ def forward(
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-
-        # number of queries per key/value
-        q_per_kv = self.num_heads // self.num_kv_heads
-        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)
 
         # Apply positional embeddings
         if self.pos_embeddings is not None:
             q = self.pos_embeddings(q, input_pos=input_pos)
 
-        # [b, n_h, s_x, h_d]
-        q = q.transpose(1, 2)
-
         # Normalize q
         if self.q_norm is not None:
             q = self.q_norm(q)
@@ -262,48 +264,17 @@ def forward(
         # Apply positional embeddings
         # k: [b, s_y, n_kv, h_d]
         k = k.view(b, s_y, self.num_kv_heads, self.head_dim)
+        v = v.view(b, s_y, self.num_kv_heads, self.head_dim)
         if self.pos_embeddings is not None:
             k = self.pos_embeddings(k, input_pos=input_pos)
 
-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # k: [b, s_y, n_kv, 1, h_d]
-        # v: [b, s_y, n_kv, 1, h_d]
-        k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-            v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-        # [b, s, n_h, h_d]
-        k = k.reshape(b, s_y, -1, self.head_dim)
-        v = v.reshape(b, s_y, -1, self.head_dim)
-
-        # [b, n_h, s, h_d]
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # Normalize k
         if self.k_norm is not None:
             k = self.k_norm(k)
 
         # Update key-value cache
         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(input_pos, k, v)
-
-        output = self._attention_call(
-            q,
-            k,
-            v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
-        )
+            SDPA.kv_cache_update(input_pos, k, v)
 
-        # reshape the output to be the same shape as the input
-        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        output = SDPA.sdpa(q, k, v, b, s_x)
         return self.output_proj(output)
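Note that the hunk above calls kv_cache_update and sdpa on the SDPA class rather than on the self._sdpa instance created in __init__. A minimal sketch of how the tail of forward might be routed through that instance, for illustration only and not part of this commit:

        # Hypothetical wiring through the instance built in __init__ (illustration only).
        # Update the key-value cache via the factored-out module, then run core attention.
        if self.kv_cache is not None:
            k, v = self._sdpa.kv_cache_update(input_pos, k, v)

        output = self._sdpa.sdpa(q, k, v, b, s_x, mask=mask)
        return self.output_proj(output)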

torchtune/modules/sdpa.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn, Tensor
+
+
+class SDPA(nn.Module):
+    """
+    The core of SDPA which can be optimized and can be swapped
+    out for a more efficient implemmentations.
+
+    TODO: word the above docstring better.
+    """
+
+    def __init__(
+        self,
+        attention_fn,
+        kv_cache,
+        q_per_kv,
+    ) -> None:
+        super().__init__()
+        self._attention_fn = attention_fn
+        self.kv_cache = kv_cache
+        self.q_per_kv = q_per_kv
+
+    def kv_cache_update(
+        self,
+        input_pos: Tensor,
+        k: Tensor,
+        v: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        k, v = self.kv_cache.update(input_pos, k, v)
+        return k, v
+
+    def sdpa(
+        self,
+        q: Tensor,  # [b, s, n_h, h_d]
+        k: Tensor,  # [b, s, n_kv, h_d]
+        v: Tensor,  # [b, s, n_kv, h_d]
+        bsz: int,
+        seq_len: int,
+        mask: Tensor = None,
+    ) -> Tensor:
+        # View + expand + reshape bring num_kv_heads to num_heads for k and v
+        # to match q.
+
+        # k: [bsz, seq_len, n_kv, 1, h_d]
+        # v: [bsz, seq_len, n_kv, 1, h_d]
+        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
+
+        # Expand the key and value tensors to have the same shape
+        # as the query tensor by copying values across the relevant dim
+        if self.num_heads != self.num_kv_heads:
+            k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
+            v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
+
+        # [bsz, s, n_h, h_d]
+        k = k.reshape(bsz, seq_len, -1, self.head_dim)
+        v = v.reshape(bsz, seq_len, -1, self.head_dim)
+
+        # [bsz, n_h, s, h_d]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        output = self._attention_fn(
+            q,
+            k,
+            v,
+            mask=mask,
+            dropout_p=self.attn_dropout,
+            is_causal=self.kv_cache is None and mask is None and self.is_causal,
+        )
+        # Reshape the output to be the same shape as the input
+        return output.transpose(1, 2).contiguous().view(b, s_x, -1)
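As committed, sdpa.py reads several attributes it never stores (num_heads, num_kv_heads, head_dim, attn_dropout, is_causal), uses the bare names q_per_kv, b, and s_x, and omits the Tuple import, so this intermediate commit does not run on its own. Below is a hedged, self-contained sketch of what the factored-out module could look like with those pieces wired in. The constructor parameters, the default of torch.nn.functional.scaled_dot_product_attention as the attention function, and the usage example are assumptions for illustration, not part of this commit.

# Self-contained sketch (assumptions labeled), not the committed code.
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class SDPA(nn.Module):
    """Core scaled dot-product attention, factored out so it can be swapped
    for a more efficient implementation."""

    def __init__(
        self,
        num_kv_heads: int,
        num_heads: int,
        head_dim: int,
        q_per_kv: int,
        attn_dropout: float = 0.0,
        is_causal: bool = True,
        kv_cache=None,
        # Assumption: default to PyTorch's built-in SDPA kernel.
        attention_fn=F.scaled_dot_product_attention,
    ) -> None:
        super().__init__()
        self.num_kv_heads = num_kv_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.q_per_kv = q_per_kv
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal
        self.kv_cache = kv_cache
        self._attention_fn = attention_fn

    def kv_cache_update(self, input_pos: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        # Write the new keys/values into the cache and return the cached views
        return self.kv_cache.update(input_pos, k, v)

    def sdpa(
        self,
        q: Tensor,  # [bsz, seq_len, n_h, h_d]
        k: Tensor,  # [bsz, seq_len, n_kv, h_d]
        v: Tensor,  # [bsz, seq_len, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Expand k and v from num_kv_heads to num_heads so they line up with q (GQA)
        k = k.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
        v = v.view(bsz, seq_len, self.num_kv_heads, 1, self.head_dim)
        if self.num_heads != self.num_kv_heads:
            k = k.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
            v = v.expand(bsz, seq_len, self.num_kv_heads, self.q_per_kv, self.head_dim)
        k = k.reshape(bsz, seq_len, -1, self.head_dim)
        v = v.reshape(bsz, seq_len, -1, self.head_dim)

        # [bsz, n_h, seq_len, h_d] for the attention kernel
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        output = self._attention_fn(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )
        # Back to [bsz, seq_len, n_h * h_d], matching the input layout
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)


# Example usage (hypothetical shapes): 32 query heads sharing 8 KV heads, so q_per_kv = 4.
attn = SDPA(num_kv_heads=8, num_heads=32, head_dim=128, q_per_kv=4)
out = attn.sdpa(
    torch.randn(2, 16, 32, 128),  # q
    torch.randn(2, 16, 8, 128),   # k
    torch.randn(2, 16, 8, 128),   # v
    bsz=2,
    seq_len=16,
)
# out has shape [2, 16, 4096]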
