
Commit fb870bc

Factor out core SDPA
1 parent c5db813 commit fb870bc

File tree (2 files changed, +90 -41 lines):
torchtune/modules/attention.py
torchtune/modules/sdpa.py


torchtune/modules/attention.py

Lines changed: 12 additions & 41 deletions
@@ -11,6 +11,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from torchtune.modules.sdpa import SDPA

 logger = logging.getLogger(__name__)

@@ -126,6 +127,8 @@ def __init__(
         self.head_dim = head_dim
         self.max_seq_len = max_seq_len
         self.is_causal = is_causal
+        # Number of queries per k, v
+        self.q_per_kv = self.num_heads // self.num_kv_heads

         # Set layers
         self.kv_cache = kv_cache
@@ -139,6 +142,11 @@ def __init__(

         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
+        self._sdpa = SDPA(
+            attention_fn=self._attention_call,
+            kv_cache=self.kv_cache,
+            q_per_kv=self.q_per_kv,
+            attn_dropout=self.attn_dropout,
+            is_causal=self.is_causal,
+        )

     def setup_cache(
         self, batch_size: int, dtype: torch.dtype, max_seq_len: int
@@ -227,18 +235,12 @@ def forward(

         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
-
-        # number of queries per key/value
-        q_per_kv = self.num_heads // self.num_kv_heads
-        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+        q = q.view(b, s_x, self.num_kv_heads * self.q_per_kv, self.head_dim)

         # Apply positional embeddings
         if self.pos_embeddings is not None:
             q = self.pos_embeddings(q, input_pos=input_pos)

-        # [b, n_h, s_x, h_d]
-        q = q.transpose(1, 2)
-
         # Normalize q
         if self.q_norm is not None:
             q = self.q_norm(q)
@@ -261,48 +263,17 @@ def forward(
         # Apply positional embeddings
         # k: [b, s_y, n_kv, h_d]
         k = k.view(b, s_y, -1, self.head_dim)
+        v = v.view(b, s_y, -1, self.head_dim)
         if self.pos_embeddings is not None:
             k = self.pos_embeddings(k, input_pos=input_pos)

-        # View + expand + reshape bring num_kv_heads to num_heads for k and v
-        # to match q.
-
-        # k: [b, s_y, n_kv, 1, h_d]
-        # v: [b, s_y, n_kv, 1, h_d]
-        k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim)
-
-        # If needed, expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-            v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim)
-
-        # [b, s, n_h, h_d]
-        k = k.reshape(b, s_y, -1, self.head_dim)
-        v = v.reshape(b, s_y, -1, self.head_dim)
-
-        # [b, n_h, s, h_d]
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
         # Normalize k
         if self.k_norm is not None:
             k = self.k_norm(k)

         # Update key-value cache
         if self.kv_cache is not None:
-            k, v = self.kv_cache.update(k, v)
-
-        output = self._attention_call(
-            q,
-            k,
-            v,
-            mask=mask,
-            dropout_p=self.attn_dropout,
-            is_causal=self.kv_cache is None and mask is None and self.is_causal,
-        )
+            k, v = self._sdpa.kv_cache_update(input_pos, k, v)

-        # reshape the output to be the same shape as the input
-        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
+        output = self._sdpa.sdpa(q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
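
The net effect of the hunks above: the view + expand + reshape that broadcasts the grouped key/value heads up to the full set of query heads, along with the final transpose and output reshape, now happens inside the SDPA module rather than inline in forward(). Below is a minimal, self-contained sketch of that broadcast, using torch.nn.functional.scaled_dot_product_attention directly in place of the _sdpa_or_flex_attention wrapper; the sizes (batch 2, 16 tokens, 32 query heads sharing 8 KV heads, head_dim 64) are illustrative, not from the commit.

import torch
import torch.nn.functional as F

# Illustrative GQA shapes: 32 query heads share 8 key/value heads.
b, s, n_kv, q_per_kv, h_d = 2, 16, 8, 4, 64
n_h = n_kv * q_per_kv

q = torch.randn(b, s, n_h, h_d)
k = torch.randn(b, s, n_kv, h_d)
v = torch.randn(b, s, n_kv, h_d)

# view + expand + reshape: replicate each kv head q_per_kv times so k/v match q.
k = k.view(b, s, n_kv, 1, h_d).expand(b, s, n_kv, q_per_kv, h_d).reshape(b, s, n_h, h_d)
v = v.view(b, s, n_kv, 1, h_d).expand(b, s, n_kv, q_per_kv, h_d).reshape(b, s, n_h, h_d)

# [b, n_h, s, h_d] layout expected by scaled_dot_product_attention
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Back to [b, s, n_h * h_d], the same output convention the module uses
out = out.transpose(1, 2).contiguous().view(b, s, -1)
print(out.shape)  # torch.Size([2, 16, 2048])

Note that expand only creates a broadcast view, so the per-head copies are not materialized until the reshape.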

torchtune/modules/sdpa.py

Lines changed: 78 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Tuple

from torch import nn, Tensor


class SDPA(nn.Module):
    """
    The core scaled dot product attention computation, factored out of the
    attention module so that it can be optimized or swapped out for a more
    efficient implementation.
    """

    def __init__(
        self,
        attention_fn: Callable,
        kv_cache: Optional[nn.Module],
        q_per_kv: int,
        attn_dropout: float = 0.0,
        is_causal: bool = True,
    ) -> None:
        super().__init__()
        self._attention_fn = attention_fn
        self.kv_cache = kv_cache
        self.q_per_kv = q_per_kv
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal

    def kv_cache_update(
        self,
        input_pos: Tensor,
        k: Tensor,
        v: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        k, v = self.kv_cache.update(input_pos, k, v)
        return k, v

    def sdpa(
        self,
        q: Tensor,  # [bsz, seq_len, n_h, h_d]
        k: Tensor,  # [bsz, seq_len, n_kv, h_d]
        v: Tensor,  # [bsz, seq_len, n_kv, h_d]
        bsz: int,
        seq_len: int,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        num_kv_heads, head_dim = k.size(2), k.size(3)

        # View + expand + reshape bring num_kv_heads to num_heads for k and v
        # to match q.

        # k: [bsz, seq_len, n_kv, 1, h_d]
        # v: [bsz, seq_len, n_kv, 1, h_d]
        k = k.view(bsz, seq_len, num_kv_heads, 1, head_dim)
        v = v.view(bsz, seq_len, num_kv_heads, 1, head_dim)

        # If needed, expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        if self.q_per_kv > 1:
            k = k.expand(bsz, seq_len, num_kv_heads, self.q_per_kv, head_dim)
            v = v.expand(bsz, seq_len, num_kv_heads, self.q_per_kv, head_dim)

        # [bsz, seq_len, n_h, h_d]
        k = k.reshape(bsz, seq_len, -1, head_dim)
        v = v.reshape(bsz, seq_len, -1, head_dim)

        # [bsz, n_h, seq_len, h_d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        output = self._attention_fn(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )

        # Reshape the output to be the same shape as the input
        return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
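
For reference, a hedged usage sketch of the new module as cleaned up above, outside of MultiHeadAttention: torch.nn.functional.scaled_dot_product_attention is wired in as attention_fn (standing in for whatever _sdpa_or_flex_attention() returns), the KV cache is left out, and the sizes are made up for illustration.

import torch
import torch.nn.functional as F

from torchtune.modules.sdpa import SDPA


def attention_fn(q, k, v, mask=None, dropout_p=0.0, is_causal=False):
    # Thin adapter with the keyword interface SDPA passes through.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=is_causal
    )


b, s, n_kv, q_per_kv, h_d = 2, 16, 8, 4, 64
sdpa = SDPA(attention_fn=attention_fn, kv_cache=None, q_per_kv=q_per_kv)

q = torch.randn(b, s, n_kv * q_per_kv, h_d)  # [b, s, n_h, h_d]
k = torch.randn(b, s, n_kv, h_d)             # [b, s, n_kv, h_d]
v = torch.randn(b, s, n_kv, h_d)

out = sdpa.sdpa(q, k, v, b, s)  # no kv_cache and no mask -> causal attention
print(out.shape)                # torch.Size([2, 16, 2048])

With kv_cache=None and no mask, the call falls through to causal attention, matching the is_causal logic carried over from the original forward().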
